diff options
Diffstat (limited to 'src')
46 files changed, 4538 insertions, 3589 deletions
diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py index f8d0d447..5a0c4e9e 100644 --- a/src/werkzeug/__init__.py +++ b/src/werkzeug/__init__.py @@ -14,12 +14,10 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from types import ModuleType import sys +from types import ModuleType -from werkzeug._compat import iteritems - -__version__ = '0.15.dev' +__version__ = "0.15.dev" # This import magic raises concerns quite often which is why the implementation # and motivation is explained here in detail now. @@ -34,82 +32,153 @@ __version__ = '0.15.dev' # werkzeug package when imported from within. Attribute access to the werkzeug # module will then lazily import from the modules that implement the objects. - # import mapping to objects in other modules all_by_module = { - 'werkzeug.debug': ['DebuggedApplication'], - 'werkzeug.local': ['Local', 'LocalManager', 'LocalProxy', 'LocalStack', - 'release_local'], - 'werkzeug.serving': ['run_simple'], - 'werkzeug.test': ['Client', 'EnvironBuilder', 'create_environ', - 'run_wsgi_app'], - 'werkzeug.testapp': ['test_app'], - 'werkzeug.exceptions': ['abort', 'Aborter'], - 'werkzeug.urls': ['url_decode', 'url_encode', 'url_quote', - 'url_quote_plus', 'url_unquote', 'url_unquote_plus', - 'url_fix', 'Href', 'iri_to_uri', 'uri_to_iri'], - 'werkzeug.formparser': ['parse_form_data'], - 'werkzeug.utils': ['escape', 'environ_property', 'append_slash_redirect', - 'redirect', 'cached_property', 'import_string', - 'dump_cookie', 'parse_cookie', 'unescape', - 'format_string', 'find_modules', 'header_property', - 'html', 'xhtml', 'HTMLBuilder', 'validate_arguments', - 'ArgumentValidationError', 'bind_arguments', - 'secure_filename'], - 'werkzeug.wsgi': ['get_current_url', 'get_host', 'pop_path_info', - 'peek_path_info', - 'ClosingIterator', 'FileWrapper', - 'make_line_iter', 'LimitedStream', 'responder', - 'wrap_file', 'extract_path_info'], - 'werkzeug.datastructures': ['MultiDict', 'CombinedMultiDict', 'Headers', - 'EnvironHeaders', 'ImmutableList', - 'ImmutableDict', 'ImmutableMultiDict', - 'TypeConversionDict', - 'ImmutableTypeConversionDict', 'Accept', - 'MIMEAccept', 'CharsetAccept', - 'LanguageAccept', 'RequestCacheControl', - 'ResponseCacheControl', 'ETags', 'HeaderSet', - 'WWWAuthenticate', 'Authorization', - 'FileMultiDict', 'CallbackDict', 'FileStorage', - 'OrderedMultiDict', 'ImmutableOrderedMultiDict' - ], - 'werkzeug.useragents': ['UserAgent'], - 'werkzeug.http': ['parse_etags', 'parse_date', 'http_date', 'cookie_date', - 'parse_cache_control_header', 'is_resource_modified', - 'parse_accept_header', 'parse_set_header', 'quote_etag', - 'unquote_etag', 'generate_etag', 'dump_header', - 'parse_list_header', 'parse_dict_header', - 'parse_authorization_header', - 'parse_www_authenticate_header', 'remove_entity_headers', - 'is_entity_header', 'remove_hop_by_hop_headers', - 'parse_options_header', 'dump_options_header', - 'is_hop_by_hop_header', 'unquote_header_value', - 'quote_header_value', 'HTTP_STATUS_CODES'], - 'werkzeug.wrappers': ['BaseResponse', 'BaseRequest', 'Request', 'Response', - 'AcceptMixin', 'ETagRequestMixin', - 'ETagResponseMixin', 'ResponseStreamMixin', - 'CommonResponseDescriptorsMixin', 'UserAgentMixin', - 'AuthorizationMixin', 'WWWAuthenticateMixin', - 'CommonRequestDescriptorsMixin'], + "werkzeug.debug": ["DebuggedApplication"], + "werkzeug.local": [ + "Local", + "LocalManager", + "LocalProxy", + "LocalStack", + "release_local", + ], + "werkzeug.serving": ["run_simple"], + "werkzeug.test": ["Client", "EnvironBuilder", "create_environ", "run_wsgi_app"], + "werkzeug.testapp": ["test_app"], + "werkzeug.exceptions": ["abort", "Aborter"], + "werkzeug.urls": [ + "url_decode", + "url_encode", + "url_quote", + "url_quote_plus", + "url_unquote", + "url_unquote_plus", + "url_fix", + "Href", + "iri_to_uri", + "uri_to_iri", + ], + "werkzeug.formparser": ["parse_form_data"], + "werkzeug.utils": [ + "escape", + "environ_property", + "append_slash_redirect", + "redirect", + "cached_property", + "import_string", + "dump_cookie", + "parse_cookie", + "unescape", + "format_string", + "find_modules", + "header_property", + "html", + "xhtml", + "HTMLBuilder", + "validate_arguments", + "ArgumentValidationError", + "bind_arguments", + "secure_filename", + ], + "werkzeug.wsgi": [ + "get_current_url", + "get_host", + "pop_path_info", + "peek_path_info", + "ClosingIterator", + "FileWrapper", + "make_line_iter", + "LimitedStream", + "responder", + "wrap_file", + "extract_path_info", + ], + "werkzeug.datastructures": [ + "MultiDict", + "CombinedMultiDict", + "Headers", + "EnvironHeaders", + "ImmutableList", + "ImmutableDict", + "ImmutableMultiDict", + "TypeConversionDict", + "ImmutableTypeConversionDict", + "Accept", + "MIMEAccept", + "CharsetAccept", + "LanguageAccept", + "RequestCacheControl", + "ResponseCacheControl", + "ETags", + "HeaderSet", + "WWWAuthenticate", + "Authorization", + "FileMultiDict", + "CallbackDict", + "FileStorage", + "OrderedMultiDict", + "ImmutableOrderedMultiDict", + ], + "werkzeug.useragents": ["UserAgent"], + "werkzeug.http": [ + "parse_etags", + "parse_date", + "http_date", + "cookie_date", + "parse_cache_control_header", + "is_resource_modified", + "parse_accept_header", + "parse_set_header", + "quote_etag", + "unquote_etag", + "generate_etag", + "dump_header", + "parse_list_header", + "parse_dict_header", + "parse_authorization_header", + "parse_www_authenticate_header", + "remove_entity_headers", + "is_entity_header", + "remove_hop_by_hop_headers", + "parse_options_header", + "dump_options_header", + "is_hop_by_hop_header", + "unquote_header_value", + "quote_header_value", + "HTTP_STATUS_CODES", + ], + "werkzeug.wrappers": [ + "BaseResponse", + "BaseRequest", + "Request", + "Response", + "AcceptMixin", + "ETagRequestMixin", + "ETagResponseMixin", + "ResponseStreamMixin", + "CommonResponseDescriptorsMixin", + "UserAgentMixin", + "AuthorizationMixin", + "WWWAuthenticateMixin", + "CommonRequestDescriptorsMixin", + ], "werkzeug.middleware.dispatcher": ["DispatcherMiddleware"], "werkzeug.middleware.shared_data": ["SharedDataMiddleware"], - 'werkzeug.security': ['generate_password_hash', 'check_password_hash'], + "werkzeug.security": ["generate_password_hash", "check_password_hash"], # the undocumented easteregg ;-) - 'werkzeug._internal': ['_easteregg'] + "werkzeug._internal": ["_easteregg"], } # modules that should be imported when accessed as attributes of werkzeug -attribute_modules = frozenset(['exceptions', 'routing']) - +attribute_modules = frozenset(["exceptions", "routing"]) object_origins = {} -for module, items in iteritems(all_by_module): +for module, items in all_by_module.items(): for item in items: object_origins[item] = module class module(ModuleType): - """Automatically import objects from the modules.""" def __getattr__(self, name): @@ -119,34 +188,46 @@ class module(ModuleType): setattr(self, extra_name, getattr(module, extra_name)) return getattr(module, name) elif name in attribute_modules: - __import__('werkzeug.' + name) + __import__("werkzeug." + name) return ModuleType.__getattribute__(self, name) def __dir__(self): """Just show what we want to show.""" result = list(new_module.__all__) - result.extend(('__file__', '__doc__', '__all__', - '__docformat__', '__name__', '__path__', - '__package__', '__version__')) + result.extend( + ( + "__file__", + "__doc__", + "__all__", + "__docformat__", + "__name__", + "__path__", + "__package__", + "__version__", + ) + ) return result + # keep a reference to this module so that it's not garbage collected -old_module = sys.modules['werkzeug'] +old_module = sys.modules["werkzeug"] # setup the new module and patch it into the dict of loaded modules -new_module = sys.modules['werkzeug'] = module('werkzeug') -new_module.__dict__.update({ - '__file__': __file__, - '__package__': 'werkzeug', - '__path__': __path__, - '__doc__': __doc__, - '__version__': __version__, - '__all__': tuple(object_origins) + tuple(attribute_modules), - '__docformat__': 'restructuredtext en' -}) +new_module = sys.modules["werkzeug"] = module("werkzeug") +new_module.__dict__.update( + { + "__file__": __file__, + "__package__": "werkzeug", + "__path__": __path__, + "__doc__": __doc__, + "__version__": __version__, + "__all__": tuple(object_origins) + tuple(attribute_modules), + "__docformat__": "restructuredtext en", + } +) # Due to bootstrapping issues we need to import exceptions here. # Don't ask :-( -__import__('werkzeug.exceptions') +__import__("werkzeug.exceptions") diff --git a/src/werkzeug/_compat.py b/src/werkzeug/_compat.py index f9c5b343..1097983e 100644 --- a/src/werkzeug/_compat.py +++ b/src/werkzeug/_compat.py @@ -1,10 +1,8 @@ # flake8: noqa # This whole file is full of lint errors -import codecs -import sys -import operator import functools -import warnings +import operator +import sys try: import builtins @@ -13,7 +11,7 @@ except ImportError: PY2 = sys.version_info[0] == 2 -WIN = sys.platform.startswith('win') +WIN = sys.platform.startswith("win") _identity = lambda x: x @@ -35,15 +33,19 @@ if PY2: import collections as collections_abc - exec('def reraise(tp, value, tb=None):\n raise tp, value, tb') + exec("def reraise(tp, value, tb=None):\n raise tp, value, tb") def fix_tuple_repr(obj): def __repr__(self): cls = self.__class__ - return '%s(%s)' % (cls.__name__, ', '.join( - '%s=%r' % (field, self[index]) - for index, field in enumerate(cls._fields) - )) + return "%s(%s)" % ( + cls.__name__, + ", ".join( + "%s=%r" % (field, self[index]) + for index, field in enumerate(cls._fields) + ), + ) + obj.__repr__ = __repr__ return obj @@ -54,12 +56,13 @@ if PY2: def implements_to_string(cls): cls.__unicode__ = cls.__str__ - cls.__str__ = lambda x: x.__unicode__().encode('utf-8') + cls.__str__ = lambda x: x.__unicode__().encode("utf-8") return cls def native_string_result(func): def wrapper(*args, **kwargs): - return func(*args, **kwargs).encode('utf-8') + return func(*args, **kwargs).encode("utf-8") + return functools.update_wrapper(wrapper, func) def implements_bool(cls): @@ -68,10 +71,12 @@ if PY2: return cls from itertools import imap, izip, ifilter + range_type = xrange from StringIO import StringIO from cStringIO import StringIO as BytesIO + NativeStringIO = BytesIO def make_literal_wrapper(reference): @@ -96,33 +101,34 @@ if PY2: wsgi_get_bytes = _identity - def wsgi_decoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_decoding_dance(s, charset="utf-8", errors="replace"): return s.decode(charset, errors) - def wsgi_encoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_encoding_dance(s, charset="utf-8", errors="replace"): if isinstance(s, bytes): return s return s.encode(charset, errors) - def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None: return None if isinstance(x, (bytes, bytearray, buffer)): return bytes(x) if isinstance(x, unicode): return x.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_native(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None or isinstance(x, str): return x return x.encode(charset, errors) + else: unichr = chr text_type = str - string_types = (str, ) - integer_types = (int, ) + string_types = (str,) + integer_types = (int,) iterkeys = lambda d, *args, **kwargs: iter(d.keys(*args, **kwargs)) itervalues = lambda d, *args, **kwargs: iter(d.values(*args, **kwargs)) @@ -131,7 +137,7 @@ else: iterlists = lambda d, *args, **kwargs: iter(d.lists(*args, **kwargs)) iterlistvalues = lambda d, *args, **kwargs: iter(d.listvalues(*args, **kwargs)) - int_to_byte = operator.methodcaller('to_bytes', 1, 'big') + int_to_byte = operator.methodcaller("to_bytes", 1, "big") iter_bytes = functools.partial(map, int_to_byte) import collections.abc as collections_abc @@ -152,9 +158,10 @@ else: range_type = range from io import StringIO, BytesIO + NativeStringIO = StringIO - _latin1_encode = operator.methodcaller('encode', 'latin1') + _latin1_encode = operator.methodcaller("encode", "latin1") def make_literal_wrapper(reference): if isinstance(reference, text_type): @@ -169,38 +176,40 @@ else: is_text = isinstance(next(tupiter, None), text_type) for arg in tupiter: if isinstance(arg, text_type) != is_text: - raise TypeError('Cannot mix str and bytes arguments (got %s)' - % repr(tup)) + raise TypeError( + "Cannot mix str and bytes arguments (got %s)" % repr(tup) + ) return tup try_coerce_native = _identity wsgi_get_bytes = _latin1_encode - def wsgi_decoding_dance(s, charset='utf-8', errors='replace'): - return s.encode('latin1').decode(charset, errors) + def wsgi_decoding_dance(s, charset="utf-8", errors="replace"): + return s.encode("latin1").decode(charset, errors) - def wsgi_encoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_encoding_dance(s, charset="utf-8", errors="replace"): if isinstance(s, text_type): s = s.encode(charset) - return s.decode('latin1', errors) + return s.decode("latin1", errors) - def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None: return None if isinstance(x, (bytes, bytearray, memoryview)): # noqa return bytes(x) if isinstance(x, str): return x.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_native(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None or isinstance(x, str): return x return x.decode(charset, errors) -def to_unicode(x, charset=sys.getdefaultencoding(), errors='strict', - allow_none_charset=False): +def to_unicode( + x, charset=sys.getdefaultencoding(), errors="strict", allow_none_charset=False +): if x is None: return None if not isinstance(x, bytes): diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 36084438..d8b83363 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -8,41 +8,46 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import inspect import re import string -import inspect -from weakref import WeakKeyDictionary -from datetime import datetime, date +from datetime import date +from datetime import datetime from itertools import chain +from weakref import WeakKeyDictionary -from werkzeug._compat import iter_bytes, text_type, int_to_byte, range_type, \ - integer_types +from ._compat import int_to_byte +from ._compat import integer_types +from ._compat import iter_bytes +from ._compat import range_type +from ._compat import text_type _logger = None _signature_cache = WeakKeyDictionary() _epoch_ord = date(1970, 1, 1).toordinal() -_cookie_params = set((b'expires', b'path', b'comment', - b'max-age', b'secure', b'httponly', - b'version')) -_legal_cookie_chars = (string.ascii_letters - + string.digits - + u"/=!#$%&'*+-.^_`|~:").encode('ascii') - -_cookie_quoting_map = { - b',': b'\\054', - b';': b'\\073', - b'"': b'\\"', - b'\\': b'\\\\', +_cookie_params = { + b"expires", + b"path", + b"comment", + b"max-age", + b"secure", + b"httponly", + b"version", } -for _i in chain(range_type(32), range_type(127, 256)): - _cookie_quoting_map[int_to_byte(_i)] = ('\\%03o' % _i).encode('latin1') +_legal_cookie_chars = ( + string.ascii_letters + string.digits + u"/=!#$%&'*+-.^_`|~:" +).encode("ascii") +_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"} +for _i in chain(range_type(32), range_type(127, 256)): + _cookie_quoting_map[int_to_byte(_i)] = ("\\%03o" % _i).encode("latin1") -_octal_re = re.compile(br'\\[0-3][0-7][0-7]') -_quote_re = re.compile(br'[\\].') -_legal_cookie_chars_re = br'[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]' -_cookie_re = re.compile(br""" +_octal_re = re.compile(br"\\[0-3][0-7][0-7]") +_quote_re = re.compile(br"[\\].") +_legal_cookie_chars_re = br"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]" +_cookie_re = re.compile( + br""" (?P<key>[^=;]+) (?:\s*=\s* (?P<val> @@ -51,24 +56,27 @@ _cookie_re = re.compile(br""" ) )? \s*; -""", flags=re.VERBOSE) +""", + flags=re.VERBOSE, +) class _Missing(object): - def __repr__(self): - return 'no value' + return "no value" def __reduce__(self): - return '_missing' + return "_missing" + _missing = _Missing() def _get_environ(obj): - env = getattr(obj, 'environ', obj) - assert isinstance(env, dict), \ - '%r is not a WSGI environment (has to be a dict)' % type(obj).__name__ + env = getattr(obj, "environ", obj) + assert isinstance(env, dict), ( + "%r is not a WSGI environment (has to be a dict)" % type(obj).__name__ + ) return env @@ -77,7 +85,8 @@ def _log(type, message, *args, **kwargs): global _logger if _logger is None: import logging - _logger = logging.getLogger('werkzeug') + + _logger = logging.getLogger("werkzeug") if _logger.level == logging.NOTSET: _logger.setLevel(logging.INFO) # Only set up a default log handler if the @@ -90,7 +99,7 @@ def _log(type, message, *args, **kwargs): def _parse_signature(func): """Return a signature object for the function.""" - if hasattr(func, 'im_func'): + if hasattr(func, "im_func"): func = func.im_func # if we have a cached validator for this function, return it @@ -99,7 +108,7 @@ def _parse_signature(func): return parse # inspect the function signature and collect all the information - if hasattr(inspect, 'getfullargspec'): + if hasattr(inspect, "getfullargspec"): tup = inspect.getfullargspec(func) else: tup = inspect.getargspec(func) @@ -109,8 +118,9 @@ def _parse_signature(func): arguments = [] for idx, name in enumerate(positional): if isinstance(name, list): - raise TypeError('cannot parse functions that unpack tuples ' - 'in the function signature') + raise TypeError( + "cannot parse functions that unpack tuples in the function signature" + ) try: default = defaults[idx - arg_count] except IndexError: @@ -150,8 +160,17 @@ def _parse_signature(func): extra.update(kwargs) kwargs = {} - return new_args, kwargs, missing, extra, extra_positional, \ - arguments, vararg_var, kwarg_var + return ( + new_args, + kwargs, + missing, + extra, + extra_positional, + arguments, + vararg_var, + kwarg_var, + ) + _signature_cache[func] = parse return parse @@ -173,12 +192,19 @@ def _date_to_unix(arg): class _DictAccessorProperty(object): - """Baseclass for `environ_property` and `header_property`.""" + read_only = False - def __init__(self, name, default=None, load_func=None, dump_func=None, - read_only=None, doc=None): + def __init__( + self, + name, + default=None, + load_func=None, + dump_func=None, + read_only=None, + doc=None, + ): self.name = name self.default = default self.load_func = load_func @@ -203,21 +229,18 @@ class _DictAccessorProperty(object): def __set__(self, obj, value): if self.read_only: - raise AttributeError('read only property') + raise AttributeError("read only property") if self.dump_func is not None: value = self.dump_func(value) self.lookup(obj)[self.name] = value def __delete__(self, obj): if self.read_only: - raise AttributeError('read only property') + raise AttributeError("read only property") self.lookup(obj).pop(self.name, None) def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, - self.name - ) + return "<%s %s>" % (self.__class__.__name__, self.name) def _cookie_quote(b): @@ -263,11 +286,11 @@ def _cookie_unquote(b): k = q_match.start(0) if q_match and (not o_match or k < j): _push(b[i:k]) - _push(b[k + 1:k + 2]) + _push(b[k + 1 : k + 2]) i = k + 2 else: _push(b[i:j]) - rv.append(int(b[j + 1:j + 4], 8)) + rv.append(int(b[j + 1 : j + 4], 8)) i = j + 4 return bytes(rv) @@ -279,12 +302,12 @@ def _cookie_parse_impl(b): n = len(b) while i < n: - match = _cookie_re.search(b + b';', i) + match = _cookie_re.search(b + b";", i) if not match: break - key = match.group('key').strip() - value = match.group('val') or b'' + key = match.group("key").strip() + value = match.group("val") or b"" i = match.end(0) # Ignore parameters. We have no interest in them. @@ -295,20 +318,20 @@ def _cookie_parse_impl(b): def _encode_idna(domain): # If we're given bytes, make sure they fit into ASCII if not isinstance(domain, text_type): - domain.decode('ascii') + domain.decode("ascii") return domain # Otherwise check if it's already ascii, then return try: - return domain.encode('ascii') + return domain.encode("ascii") except UnicodeError: pass # Otherwise encode each part separately - parts = domain.split('.') + parts = domain.split(".") for idx, part in enumerate(parts): - parts[idx] = part.encode('idna') - return b'.'.join(parts) + parts[idx] = part.encode("idna") + return b".".join(parts) def _decode_idna(domain): @@ -317,47 +340,54 @@ def _decode_idna(domain): # unicode error, then we already have a decoded idna domain if isinstance(domain, text_type): try: - domain = domain.encode('ascii') + domain = domain.encode("ascii") except UnicodeError: return domain # Decode each part separately. If a part fails, try to # decode it with ascii and silently ignore errors. This makes # most sense because the idna codec does not have error handling - parts = domain.split(b'.') + parts = domain.split(b".") for idx, part in enumerate(parts): try: - parts[idx] = part.decode('idna') + parts[idx] = part.decode("idna") except UnicodeError: - parts[idx] = part.decode('ascii', 'ignore') + parts[idx] = part.decode("ascii", "ignore") - return '.'.join(parts) + return ".".join(parts) def _make_cookie_domain(domain): if domain is None: return None domain = _encode_idna(domain) - if b':' in domain: - domain = domain.split(b':', 1)[0] - if b'.' in domain: + if b":" in domain: + domain = domain.split(b":", 1)[0] + if b"." in domain: return domain raise ValueError( - 'Setting \'domain\' for a cookie on a server running locally (ex: ' - 'localhost) is not supported by complying browsers. You should ' - 'have something like: \'127.0.0.1 localhost dev.localhost\' on ' - 'your hosts file and then point your server to run on ' - '\'dev.localhost\' and also set \'domain\' for \'dev.localhost\'' + "Setting 'domain' for a cookie on a server running locally (ex: " + "localhost) is not supported by complying browsers. You should " + "have something like: '127.0.0.1 localhost dev.localhost' on " + "your hosts file and then point your server to run on " + "'dev.localhost' and also set 'domain' for 'dev.localhost'" ) def _easteregg(app=None): """Like the name says. But who knows how it works?""" + def bzzzzzzz(gyver): import base64 import zlib - return zlib.decompress(base64.b64decode(gyver)).decode('ascii') - gyver = u'\n'.join([x + (77 - len(x)) * u' ' for x in bzzzzzzz(b''' + + return zlib.decompress(base64.b64decode(gyver)).decode("ascii") + + gyver = u"\n".join( + [ + x + (77 - len(x)) * u" " + for x in bzzzzzzz( + b""" eJyFlzuOJDkMRP06xRjymKgDJCDQStBYT8BCgK4gTwfQ2fcFs2a2FzvZk+hvlcRvRJD148efHt9m 9Xz94dRY5hGt1nrYcXx7us9qlcP9HHNh28rz8dZj+q4rynVFFPdlY4zH873NKCexrDM6zxxRymzz 4QIxzK4bth1PV7+uHn6WXZ5C4ka/+prFzx3zWLMHAVZb8RRUxtFXI5DTQ2n3Hi2sNI+HK43AOWSY @@ -388,16 +418,22 @@ p1qXK3Du2mnr5INXmT/78KI12n11EFBkJHHp0wJyLe9MvPNUGYsf+170maayRoy2lURGHAIapSpQ krEDuNoJCHNlZYhKpvw4mspVWxqo415n8cD62N9+EfHrAvqQnINStetek7RY2Urv8nxsnGaZfRr/ nhXbJ6m/yl1LzYqscDZA9QHLNbdaSTTr+kFg3bC0iYbX/eQy0Bv3h4B50/SGYzKAXkCeOLI3bcAt mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p -7f2zLkGNv8b191cD/3vs9Q833z8t''').splitlines()]) +7f2zLkGNv8b191cD/3vs9Q833z8t""" + ).splitlines() + ] + ) def easteregged(environ, start_response): def injecting_start_response(status, headers, exc_info=None): - headers.append(('X-Powered-By', 'Werkzeug')) + headers.append(("X-Powered-By", "Werkzeug")) return start_response(status, headers, exc_info) - if app is not None and environ.get('QUERY_STRING') != 'macgybarchakku': + + if app is not None and environ.get("QUERY_STRING") != "macgybarchakku": return app(environ, injecting_start_response) - injecting_start_response('200 OK', [('Content-Type', 'text/html')]) - return [(u''' + injecting_start_response("200 OK", [("Content-Type", "text/html")]) + return [ + ( + u""" <!DOCTYPE html> <html> <head> @@ -415,5 +451,9 @@ mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p <p>the Swiss Army knife of Python web development.</p> <pre>%s\n\n\n</pre> </body> -</html>''' % gyver).encode('latin1')] +</html>""" + % gyver + ).encode("latin1") + ] + return easteregged diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 2b21e146..f06a63d5 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -1,12 +1,14 @@ import os -import sys -import time import subprocess +import sys import threading +import time from itertools import chain -from werkzeug._internal import _log -from werkzeug._compat import PY2, iteritems, text_type +from ._compat import iteritems +from ._compat import PY2 +from ._compat import text_type +from ._internal import _log def _iter_module_files(): @@ -19,10 +21,11 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: - if os.path.isdir(filename) and \ - os.path.exists(os.path.join(filename, "__init__.py")): + if os.path.isdir(filename) and os.path.exists( + os.path.join(filename, "__init__.py") + ): filename = os.path.join(filename, "__init__.py") old = None @@ -32,22 +35,23 @@ def _iter_module_files(): if filename == old: break else: - if filename[-4:] in ('.pyc', '.pyo'): + if filename[-4:] in (".pyc", ".pyo"): filename = filename[:-1] yield filename def _find_observable_paths(extra_files=None): """Finds all paths that should be observed.""" - rv = set(os.path.dirname(os.path.abspath(x)) - if os.path.isfile(x) else os.path.abspath(x) - for x in sys.path) + rv = set( + os.path.dirname(os.path.abspath(x)) if os.path.isfile(x) else os.path.abspath(x) + for x in sys.path + ) for filename in extra_files or (): rv.add(os.path.dirname(os.path.abspath(filename))) for module in list(sys.modules.values()): - fn = getattr(module, '__file__', None) + fn = getattr(module, "__file__", None) if fn is None: continue fn = os.path.abspath(fn) @@ -125,7 +129,8 @@ def _find_common_roots(paths): for prefix, child in iteritems(node): _walk(child, path + (prefix,)) if not node: - rv.add('/'.join(path)) + rv.add("/".join(path)) + _walk(root, ()) return rv @@ -139,8 +144,7 @@ class ReloaderLoop(object): _sleep = staticmethod(time.sleep) def __init__(self, extra_files=None, interval=1): - self.extra_files = set(os.path.abspath(x) - for x in extra_files or ()) + self.extra_files = set(os.path.abspath(x) for x in extra_files or ()) self.interval = interval def run(self): @@ -151,26 +155,25 @@ class ReloaderLoop(object): but running the reloader thread. """ while 1: - _log('info', ' * Restarting with %s' % self.name) + _log("info", " * Restarting with %s" % self.name) args = _get_args_for_reloading() # a weird bug on windows. sometimes unicode strings end up in the # environment and subprocess.call does not like this, encode them # to latin1 and continue. - if os.name == 'nt' and PY2: + if os.name == "nt" and PY2: new_environ = {} for key, value in iteritems(os.environ): if isinstance(key, text_type): - key = key.encode('iso-8859-1') + key = key.encode("iso-8859-1") if isinstance(value, text_type): - value = value.encode('iso-8859-1') + value = value.encode("iso-8859-1") new_environ[key] = value else: new_environ = os.environ.copy() - new_environ['WERKZEUG_RUN_MAIN'] = 'true' - exit_code = subprocess.call(args, env=new_environ, - close_fds=False) + new_environ["WERKZEUG_RUN_MAIN"] = "true" + exit_code = subprocess.call(args, env=new_environ, close_fds=False) if exit_code != 3: return exit_code @@ -180,17 +183,16 @@ class ReloaderLoop(object): def log_reload(self, filename): filename = os.path.abspath(filename) - _log('info', ' * Detected change in %r, reloading' % filename) + _log("info", " * Detected change in %r, reloading" % filename) class StatReloaderLoop(ReloaderLoop): - name = 'stat' + name = "stat" def run(self): mtimes = {} while 1: - for filename in chain(_iter_module_files(), - self.extra_files): + for filename in chain(_iter_module_files(), self.extra_files): try: mtime = os.stat(filename).st_mtime except OSError: @@ -206,11 +208,11 @@ class StatReloaderLoop(ReloaderLoop): class WatchdogReloaderLoop(ReloaderLoop): - def __init__(self, *args, **kwargs): ReloaderLoop.__init__(self, *args, **kwargs) from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler + self.observable_paths = set() def _check_modification(filename): @@ -218,11 +220,10 @@ class WatchdogReloaderLoop(ReloaderLoop): self.trigger_reload(filename) dirname = os.path.dirname(filename) if dirname.startswith(tuple(self.observable_paths)): - if filename.endswith(('.pyc', '.pyo', '.py')): + if filename.endswith((".pyc", ".pyo", ".py")): self.trigger_reload(filename) class _CustomHandler(FileSystemEventHandler): - def on_created(self, event): _check_modification(event.src_path) @@ -237,9 +238,9 @@ class WatchdogReloaderLoop(ReloaderLoop): _check_modification(event.src_path) reloader_name = Observer.__name__.lower() - if reloader_name.endswith('observer'): + if reloader_name.endswith("observer"): reloader_name = reloader_name[:-8] - reloader_name += ' reloader' + reloader_name += " reloader" self.name = reloader_name @@ -267,7 +268,8 @@ class WatchdogReloaderLoop(ReloaderLoop): if path not in watches: try: watches[path] = observer.schedule( - self.event_handler, path, recursive=True) + self.event_handler, path, recursive=True + ) except OSError: # Clear this path from list of watches We don't want # the same error message showing again in the next @@ -287,17 +289,14 @@ class WatchdogReloaderLoop(ReloaderLoop): sys.exit(3) -reloader_loops = { - 'stat': StatReloaderLoop, - 'watchdog': WatchdogReloaderLoop, -} +reloader_loops = {"stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop} try: - __import__('watchdog.observers') + __import__("watchdog.observers") except ImportError: - reloader_loops['auto'] = reloader_loops['stat'] + reloader_loops["auto"] = reloader_loops["stat"] else: - reloader_loops['auto'] = reloader_loops['watchdog'] + reloader_loops["auto"] = reloader_loops["watchdog"] def ensure_echo_on(): @@ -316,14 +315,14 @@ def ensure_echo_on(): termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes) -def run_with_reloader(main_func, extra_files=None, interval=1, - reloader_type='auto'): +def run_with_reloader(main_func, extra_files=None, interval=1, reloader_type="auto"): """Run the given function in an independent python interpreter.""" import signal + reloader = reloader_loops[reloader_type](extra_files, interval) signal.signal(signal.SIGTERM, lambda *args: sys.exit(0)) try: - if os.environ.get('WERKZEUG_RUN_MAIN') == 'true': + if os.environ.get("WERKZEUG_RUN_MAIN") == "true": ensure_echo_on() t = threading.Thread(target=main_func, args=()) t.setDaemon(True) diff --git a/src/werkzeug/contrib/atom.py b/src/werkzeug/contrib/atom.py index 5c631f43..d079d2bf 100644 --- a/src/werkzeug/contrib/atom.py +++ b/src/werkzeug/contrib/atom.py @@ -23,9 +23,11 @@ """ import warnings from datetime import datetime -from werkzeug.utils import escape -from werkzeug.wrappers import BaseResponse -from werkzeug._compat import implements_to_string, string_types + +from .._compat import implements_to_string +from .._compat import string_types +from ..utils import escape +from ..wrappers import BaseResponse warnings.warn( "'werkzeug.contrib.atom' is deprecated as of version 0.15 and will" @@ -34,18 +36,21 @@ warnings.warn( stacklevel=2, ) -XHTML_NAMESPACE = 'http://www.w3.org/1999/xhtml' +XHTML_NAMESPACE = "http://www.w3.org/1999/xhtml" def _make_text_block(name, content, content_type=None): """Helper function for the builder that creates an XML text block.""" - if content_type == 'xhtml': - return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % \ - (name, XHTML_NAMESPACE, content, name) + if content_type == "xhtml": + return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % ( + name, + XHTML_NAMESPACE, + content, + name, + ) if not content_type: - return u'<%s>%s</%s>\n' % (name, escape(content), name) - return u'<%s type="%s">%s</%s>\n' % (name, content_type, - escape(content), name) + return u"<%s>%s</%s>\n" % (name, escape(content), name) + return u'<%s type="%s">%s</%s>\n' % (name, content_type, escape(content), name) def format_iso8601(obj): @@ -53,7 +58,7 @@ def format_iso8601(obj): iso8601 = obj.isoformat() if obj.tzinfo: return iso8601 - return iso8601 + 'Z' + return iso8601 + "Z" @implements_to_string @@ -106,42 +111,43 @@ class AtomFeed(object): Everywhere where a list is demanded, any iterable can be used. """ - default_generator = ('Werkzeug', None, None) + default_generator = ("Werkzeug", None, None) def __init__(self, title=None, entries=None, **kwargs): self.title = title - self.title_type = kwargs.get('title_type', 'text') - self.url = kwargs.get('url') - self.feed_url = kwargs.get('feed_url', self.url) - self.id = kwargs.get('id', self.feed_url) - self.updated = kwargs.get('updated') - self.author = kwargs.get('author', ()) - self.icon = kwargs.get('icon') - self.logo = kwargs.get('logo') - self.rights = kwargs.get('rights') - self.rights_type = kwargs.get('rights_type') - self.subtitle = kwargs.get('subtitle') - self.subtitle_type = kwargs.get('subtitle_type', 'text') - self.generator = kwargs.get('generator') + self.title_type = kwargs.get("title_type", "text") + self.url = kwargs.get("url") + self.feed_url = kwargs.get("feed_url", self.url) + self.id = kwargs.get("id", self.feed_url) + self.updated = kwargs.get("updated") + self.author = kwargs.get("author", ()) + self.icon = kwargs.get("icon") + self.logo = kwargs.get("logo") + self.rights = kwargs.get("rights") + self.rights_type = kwargs.get("rights_type") + self.subtitle = kwargs.get("subtitle") + self.subtitle_type = kwargs.get("subtitle_type", "text") + self.generator = kwargs.get("generator") if self.generator is None: self.generator = self.default_generator - self.links = kwargs.get('links', []) + self.links = kwargs.get("links", []) self.entries = list(entries) if entries else [] - if not hasattr(self.author, '__iter__') \ - or isinstance(self.author, string_types + (dict,)): + if not hasattr(self.author, "__iter__") or isinstance( + self.author, string_types + (dict,) + ): self.author = [self.author] for i, author in enumerate(self.author): if not isinstance(author, dict): - self.author[i] = {'name': author} + self.author[i] = {"name": author} if not self.title: - raise ValueError('title is required') + raise ValueError("title is required") if not self.id: - raise ValueError('id is required') + raise ValueError("id is required") for author in self.author: - if 'name' not in author: - raise TypeError('author must contain at least a name') + if "name" not in author: + raise TypeError("author must contain at least a name") def add(self, *args, **kwargs): """Add a new entry to the feed. This function can either be called @@ -151,14 +157,14 @@ class AtomFeed(object): if len(args) == 1 and not kwargs and isinstance(args[0], FeedEntry): self.entries.append(args[0]) else: - kwargs['feed_url'] = self.feed_url + kwargs["feed_url"] = self.feed_url self.entries.append(FeedEntry(*args, **kwargs)) def __repr__(self): - return '<%s %r (%d entries)>' % ( + return "<%s %r (%d entries)>" % ( self.__class__.__name__, self.title, - len(self.entries) + len(self.entries), ) def generate(self): @@ -166,7 +172,7 @@ class AtomFeed(object): # atom demands either an author element in every entry or a global one if not self.author: if any(not e.author for e in self.entries): - self.author = ({'name': 'Unknown author'},) + self.author = ({"name": "Unknown author"},) if not self.updated: dates = sorted([entry.updated for entry in self.entries]) @@ -174,56 +180,54 @@ class AtomFeed(object): yield u'<?xml version="1.0" encoding="utf-8"?>\n' yield u'<feed xmlns="http://www.w3.org/2005/Atom">\n' - yield ' ' + _make_text_block('title', self.title, self.title_type) - yield u' <id>%s</id>\n' % escape(self.id) - yield u' <updated>%s</updated>\n' % format_iso8601(self.updated) + yield " " + _make_text_block("title", self.title, self.title_type) + yield u" <id>%s</id>\n" % escape(self.id) + yield u" <updated>%s</updated>\n" % format_iso8601(self.updated) if self.url: yield u' <link href="%s" />\n' % escape(self.url) if self.feed_url: - yield u' <link href="%s" rel="self" />\n' % \ - escape(self.feed_url) + yield u' <link href="%s" rel="self" />\n' % escape(self.feed_url) for link in self.links: - yield u' <link %s/>\n' % ''.join('%s="%s" ' % - (k, escape(link[k])) for k in link) + yield u" <link %s/>\n" % "".join( + '%s="%s" ' % (k, escape(link[k])) for k in link + ) for author in self.author: - yield u' <author>\n' - yield u' <name>%s</name>\n' % escape(author['name']) - if 'uri' in author: - yield u' <uri>%s</uri>\n' % escape(author['uri']) - if 'email' in author: - yield ' <email>%s</email>\n' % escape(author['email']) - yield ' </author>\n' + yield u" <author>\n" + yield u" <name>%s</name>\n" % escape(author["name"]) + if "uri" in author: + yield u" <uri>%s</uri>\n" % escape(author["uri"]) + if "email" in author: + yield " <email>%s</email>\n" % escape(author["email"]) + yield " </author>\n" if self.subtitle: - yield ' ' + _make_text_block('subtitle', self.subtitle, - self.subtitle_type) + yield " " + _make_text_block("subtitle", self.subtitle, self.subtitle_type) if self.icon: - yield u' <icon>%s</icon>\n' % escape(self.icon) + yield u" <icon>%s</icon>\n" % escape(self.icon) if self.logo: - yield u' <logo>%s</logo>\n' % escape(self.logo) + yield u" <logo>%s</logo>\n" % escape(self.logo) if self.rights: - yield ' ' + _make_text_block('rights', self.rights, - self.rights_type) + yield " " + _make_text_block("rights", self.rights, self.rights_type) generator_name, generator_url, generator_version = self.generator if generator_name or generator_url or generator_version: - tmp = [u' <generator'] + tmp = [u" <generator"] if generator_url: tmp.append(u' uri="%s"' % escape(generator_url)) if generator_version: tmp.append(u' version="%s"' % escape(generator_version)) - tmp.append(u'>%s</generator>\n' % escape(generator_name)) - yield u''.join(tmp) + tmp.append(u">%s</generator>\n" % escape(generator_name)) + yield u"".join(tmp) for entry in self.entries: for line in entry.generate(): - yield u' ' + line - yield u'</feed>\n' + yield u" " + line + yield u"</feed>\n" def to_string(self): """Convert the feed into a string.""" - return u''.join(self.generate()) + return u"".join(self.generate()) def get_response(self): """Return a response object for the feed.""" - return BaseResponse(self.to_string(), mimetype='application/atom+xml') + return BaseResponse(self.to_string(), mimetype="application/atom+xml") def __call__(self, environ, start_response): """Use the class as WSGI response object.""" @@ -282,80 +286,77 @@ class FeedEntry(object): def __init__(self, title=None, content=None, feed_url=None, **kwargs): self.title = title - self.title_type = kwargs.get('title_type', 'text') + self.title_type = kwargs.get("title_type", "text") self.content = content - self.content_type = kwargs.get('content_type', 'html') - self.url = kwargs.get('url') - self.id = kwargs.get('id', self.url) - self.updated = kwargs.get('updated') - self.summary = kwargs.get('summary') - self.summary_type = kwargs.get('summary_type', 'html') - self.author = kwargs.get('author', ()) - self.published = kwargs.get('published') - self.rights = kwargs.get('rights') - self.links = kwargs.get('links', []) - self.categories = kwargs.get('categories', []) - self.xml_base = kwargs.get('xml_base', feed_url) - - if not hasattr(self.author, '__iter__') \ - or isinstance(self.author, string_types + (dict,)): + self.content_type = kwargs.get("content_type", "html") + self.url = kwargs.get("url") + self.id = kwargs.get("id", self.url) + self.updated = kwargs.get("updated") + self.summary = kwargs.get("summary") + self.summary_type = kwargs.get("summary_type", "html") + self.author = kwargs.get("author", ()) + self.published = kwargs.get("published") + self.rights = kwargs.get("rights") + self.links = kwargs.get("links", []) + self.categories = kwargs.get("categories", []) + self.xml_base = kwargs.get("xml_base", feed_url) + + if not hasattr(self.author, "__iter__") or isinstance( + self.author, string_types + (dict,) + ): self.author = [self.author] for i, author in enumerate(self.author): if not isinstance(author, dict): - self.author[i] = {'name': author} + self.author[i] = {"name": author} if not self.title: - raise ValueError('title is required') + raise ValueError("title is required") if not self.id: - raise ValueError('id is required') + raise ValueError("id is required") if not self.updated: - raise ValueError('updated is required') + raise ValueError("updated is required") def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.title - ) + return "<%s %r>" % (self.__class__.__name__, self.title) def generate(self): """Yields pieces of ATOM XML.""" - base = '' + base = "" if self.xml_base: base = ' xml:base="%s"' % escape(self.xml_base) - yield u'<entry%s>\n' % base - yield u' ' + _make_text_block('title', self.title, self.title_type) - yield u' <id>%s</id>\n' % escape(self.id) - yield u' <updated>%s</updated>\n' % format_iso8601(self.updated) + yield u"<entry%s>\n" % base + yield u" " + _make_text_block("title", self.title, self.title_type) + yield u" <id>%s</id>\n" % escape(self.id) + yield u" <updated>%s</updated>\n" % format_iso8601(self.updated) if self.published: - yield u' <published>%s</published>\n' % \ - format_iso8601(self.published) + yield u" <published>%s</published>\n" % format_iso8601(self.published) if self.url: yield u' <link href="%s" />\n' % escape(self.url) for author in self.author: - yield u' <author>\n' - yield u' <name>%s</name>\n' % escape(author['name']) - if 'uri' in author: - yield u' <uri>%s</uri>\n' % escape(author['uri']) - if 'email' in author: - yield u' <email>%s</email>\n' % escape(author['email']) - yield u' </author>\n' + yield u" <author>\n" + yield u" <name>%s</name>\n" % escape(author["name"]) + if "uri" in author: + yield u" <uri>%s</uri>\n" % escape(author["uri"]) + if "email" in author: + yield u" <email>%s</email>\n" % escape(author["email"]) + yield u" </author>\n" for link in self.links: - yield u' <link %s/>\n' % ''.join('%s="%s" ' % - (k, escape(link[k])) for k in link) + yield u" <link %s/>\n" % "".join( + '%s="%s" ' % (k, escape(link[k])) for k in link + ) for category in self.categories: - yield u' <category %s/>\n' % ''.join('%s="%s" ' % - (k, escape(category[k])) for k in category) + yield u" <category %s/>\n" % "".join( + '%s="%s" ' % (k, escape(category[k])) for k in category + ) if self.summary: - yield u' ' + _make_text_block('summary', self.summary, - self.summary_type) + yield u" " + _make_text_block("summary", self.summary, self.summary_type) if self.content: - yield u' ' + _make_text_block('content', self.content, - self.content_type) - yield u'</entry>\n' + yield u" " + _make_text_block("content", self.content, self.content_type) + yield u"</entry>\n" def to_string(self): """Convert the feed item into a unicode object.""" - return u''.join(self.generate()) + return u"".join(self.generate()) def __str__(self): return self.to_string() diff --git a/src/werkzeug/contrib/cache.py b/src/werkzeug/contrib/cache.py index 8f5a33ae..79c749b5 100644 --- a/src/werkzeug/contrib/cache.py +++ b/src/werkzeug/contrib/cache.py @@ -56,23 +56,27 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import errno import os +import platform import re -import errno import tempfile -import platform import warnings from hashlib import md5 from time import time + +from .._compat import integer_types +from .._compat import iteritems +from .._compat import string_types +from .._compat import text_type +from .._compat import to_native +from ..posixemulation import rename + try: import cPickle as pickle except ImportError: # pragma: no cover import pickle -from werkzeug._compat import iteritems, string_types, text_type, \ - integer_types, to_native -from werkzeug.posixemulation import rename - warnings.warn( "'werkzeug.contrib.cache' is deprecated as of version 0.15 and will" " be removed in version 1.0. It has moved to https://github.com" @@ -93,13 +97,12 @@ def _items(mappingorseq): ... assert k*k == v """ - if hasattr(mappingorseq, 'items'): + if hasattr(mappingorseq, "items"): return iteritems(mappingorseq) return mappingorseq class BaseCache(object): - """Baseclass for the cache systems. All the cache systems implement this API or a superset of it. @@ -224,10 +227,10 @@ class BaseCache(object): :param key: the key to check """ raise NotImplementedError( - '%s doesn\'t have an efficient implementation of `has`. That ' - 'means it is impossible to check whether a key exists without ' - 'fully loading the key\'s data. Consider using `self.get` ' - 'explicitly if you don\'t care about performance.' + "%s doesn't have an efficient implementation of `has`. That " + "means it is impossible to check whether a key exists without " + "fully loading the key's data. Consider using `self.get` " + "explicitly if you don't care about performance." ) def clear(self): @@ -267,7 +270,6 @@ class BaseCache(object): class NullCache(BaseCache): - """A cache that doesn't cache. This can be useful for unit testing. :param default_timeout: a dummy parameter that is ignored but exists @@ -279,7 +281,6 @@ class NullCache(BaseCache): class SimpleCache(BaseCache): - """Simple memory cache for single process environments. This class exists mainly for the development server and is not 100% thread safe. It tries to use as many atomic operations as possible and no locks for simplicity @@ -325,15 +326,13 @@ class SimpleCache(BaseCache): def set(self, key, value, timeout=None): expires = self._normalize_timeout(timeout) self._prune() - self._cache[key] = (expires, pickle.dumps(value, - pickle.HIGHEST_PROTOCOL)) + self._cache[key] = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) return True def add(self, key, value, timeout=None): expires = self._normalize_timeout(timeout) self._prune() - item = (expires, pickle.dumps(value, - pickle.HIGHEST_PROTOCOL)) + item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) if key in self._cache: return False self._cache.setdefault(key, item) @@ -349,11 +348,11 @@ class SimpleCache(BaseCache): except KeyError: return False -_test_memcached_key = re.compile(r'[^\x00-\x21\xff]{1,250}$').match +_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match -class MemcachedCache(BaseCache): +class MemcachedCache(BaseCache): """A cache that uses memcached as backend. The first argument can either be an object that resembles the API of a @@ -392,10 +391,10 @@ class MemcachedCache(BaseCache): BaseCache.__init__(self, default_timeout) if servers is None or isinstance(servers, (list, tuple)): if servers is None: - servers = ['127.0.0.1:11211'] + servers = ["127.0.0.1:11211"] self._client = self.import_preferred_memcache_lib(servers) if self._client is None: - raise RuntimeError('no memcache module found') + raise RuntimeError("no memcache module found") else: # NOTE: servers is actually an already initialized memcache # client. @@ -404,7 +403,7 @@ class MemcachedCache(BaseCache): self.key_prefix = to_native(key_prefix) def _normalize_key(self, key): - key = to_native(key, 'utf-8') + key = to_native(key, "utf-8") if self.key_prefix: key = self.key_prefix + key return key @@ -484,7 +483,7 @@ class MemcachedCache(BaseCache): def has(self, key): key = self._normalize_key(key) if _test_memcached_key(key): - return self._client.append(key, '') + return self._client.append(key, "") return False def clear(self): @@ -534,7 +533,6 @@ GAEMemcachedCache = MemcachedCache class RedisCache(BaseCache): - """Uses the Redis key-value store as a cache backend. The first argument can be either a string denoting address of the Redis @@ -570,24 +568,32 @@ class RedisCache(BaseCache): Any additional keyword arguments will be passed to ``redis.Redis``. """ - def __init__(self, host='localhost', port=6379, password=None, - db=0, default_timeout=300, key_prefix=None, **kwargs): + def __init__( + self, + host="localhost", + port=6379, + password=None, + db=0, + default_timeout=300, + key_prefix=None, + **kwargs + ): BaseCache.__init__(self, default_timeout) if host is None: - raise ValueError('RedisCache host parameter may not be None') + raise ValueError("RedisCache host parameter may not be None") if isinstance(host, string_types): try: import redis except ImportError: - raise RuntimeError('no redis module found') - if kwargs.get('decode_responses', None): - raise ValueError('decode_responses is not supported by ' - 'RedisCache.') - self._client = redis.Redis(host=host, port=port, password=password, - db=db, **kwargs) + raise RuntimeError("no redis module found") + if kwargs.get("decode_responses", None): + raise ValueError("decode_responses is not supported by RedisCache.") + self._client = redis.Redis( + host=host, port=port, password=password, db=db, **kwargs + ) else: self._client = host - self.key_prefix = key_prefix or '' + self.key_prefix = key_prefix or "" def _normalize_timeout(self, timeout): timeout = BaseCache._normalize_timeout(self, timeout) @@ -601,8 +607,8 @@ class RedisCache(BaseCache): """ t = type(value) if t in integer_types: - return str(value).encode('ascii') - return b'!' + pickle.dumps(value) + return str(value).encode("ascii") + return b"!" + pickle.dumps(value) def load_object(self, value): """The reversal of :meth:`dump_object`. This might be called with @@ -610,7 +616,7 @@ class RedisCache(BaseCache): """ if value is None: return None - if value.startswith(b'!'): + if value.startswith(b"!"): try: return pickle.loads(value[1:]) except pickle.PickleError: @@ -633,20 +639,19 @@ class RedisCache(BaseCache): timeout = self._normalize_timeout(timeout) dump = self.dump_object(value) if timeout == -1: - result = self._client.set(name=self.key_prefix + key, - value=dump) + result = self._client.set(name=self.key_prefix + key, value=dump) else: - result = self._client.setex(name=self.key_prefix + key, - value=dump, time=timeout) + result = self._client.setex( + name=self.key_prefix + key, value=dump, time=timeout + ) return result def add(self, key, value, timeout=None): timeout = self._normalize_timeout(timeout) dump = self.dump_object(value) - return ( - self._client.setnx(name=self.key_prefix + key, value=dump) - and self._client.expire(name=self.key_prefix + key, time=timeout) - ) + return self._client.setnx( + name=self.key_prefix + key, value=dump + ) and self._client.expire(name=self.key_prefix + key, time=timeout) def set_many(self, mapping, timeout=None): timeout = self._normalize_timeout(timeout) @@ -659,8 +664,7 @@ class RedisCache(BaseCache): if timeout == -1: pipe.set(name=self.key_prefix + key, value=dump) else: - pipe.setex(name=self.key_prefix + key, value=dump, - time=timeout) + pipe.setex(name=self.key_prefix + key, value=dump, time=timeout) return pipe.execute() def delete(self, key): @@ -679,7 +683,7 @@ class RedisCache(BaseCache): def clear(self): status = False if self.key_prefix: - keys = self._client.keys(self.key_prefix + '*') + keys = self._client.keys(self.key_prefix + "*") if keys: status = self._client.delete(*keys) else: @@ -694,7 +698,6 @@ class RedisCache(BaseCache): class FileSystemCache(BaseCache): - """A cache that stores the items on the file system. This cache depends on being the only user of the `cache_dir`. Make absolutely sure that nobody but this cache stores files there or otherwise the cache will @@ -711,12 +714,11 @@ class FileSystemCache(BaseCache): """ #: used for temporary files by the FileSystemCache - _fs_transaction_suffix = '.__wz_cache' + _fs_transaction_suffix = ".__wz_cache" #: keep amount of files in a cache element - _fs_count_file = '__wz_cache_count' + _fs_count_file = "__wz_cache_count" - def __init__(self, cache_dir, threshold=500, default_timeout=300, - mode=0o600): + def __init__(self, cache_dir, threshold=500, default_timeout=300, mode=0o600): BaseCache.__init__(self, default_timeout) self._path = cache_dir self._threshold = threshold @@ -754,11 +756,14 @@ class FileSystemCache(BaseCache): def _list_dir(self): """return a list of (fully qualified) cache filenames """ - mgmt_files = [self._get_filename(name).split('/')[-1] - for name in (self._fs_count_file,)] - return [os.path.join(self._path, fn) for fn in os.listdir(self._path) - if not fn.endswith(self._fs_transaction_suffix) - and fn not in mgmt_files] + mgmt_files = [ + self._get_filename(name).split("/")[-1] for name in (self._fs_count_file,) + ] + return [ + os.path.join(self._path, fn) + for fn in os.listdir(self._path) + if not fn.endswith(self._fs_transaction_suffix) and fn not in mgmt_files + ] def _prune(self): if self._threshold == 0 or not self._file_count > self._threshold: @@ -769,7 +774,7 @@ class FileSystemCache(BaseCache): for idx, fname in enumerate(entries): try: remove = False - with open(fname, 'rb') as f: + with open(fname, "rb") as f: expires = pickle.load(f) remove = (expires != 0 and expires <= now) or idx % 3 == 0 @@ -791,14 +796,14 @@ class FileSystemCache(BaseCache): def _get_filename(self, key): if isinstance(key, text_type): - key = key.encode('utf-8') # XXX unicode review + key = key.encode("utf-8") # XXX unicode review hash = md5(key).hexdigest() return os.path.join(self._path, hash) def get(self, key): filename = self._get_filename(key) try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: pickle_time = pickle.load(f) if pickle_time == 0 or pickle_time >= time(): return pickle.load(f) @@ -826,9 +831,10 @@ class FileSystemCache(BaseCache): timeout = self._normalize_timeout(timeout) filename = self._get_filename(key) try: - fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix, - dir=self._path) - with os.fdopen(fd, 'wb') as f: + fd, tmp = tempfile.mkstemp( + suffix=self._fs_transaction_suffix, dir=self._path + ) + with os.fdopen(fd, "wb") as f: pickle.dump(timeout, f, 1) pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) rename(tmp, filename) @@ -855,7 +861,7 @@ class FileSystemCache(BaseCache): def has(self, key): filename = self._get_filename(key) try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: pickle_time = pickle.load(f) if pickle_time == 0 or pickle_time >= time(): return True @@ -867,7 +873,7 @@ class FileSystemCache(BaseCache): class UWSGICache(BaseCache): - """ Implements the cache using uWSGI's caching framework. + """Implements the cache using uWSGI's caching framework. .. note:: This class cannot be used when running under PyPy, because the uWSGI @@ -880,19 +886,24 @@ class UWSGICache(BaseCache): same instance as the werkzeug app, you only have to provide the name of the cache. """ - def __init__(self, default_timeout=300, cache=''): + + def __init__(self, default_timeout=300, cache=""): BaseCache.__init__(self, default_timeout) - if platform.python_implementation() == 'PyPy': - raise RuntimeError("uWSGI caching does not work under PyPy, see " - "the docs for more details.") + if platform.python_implementation() == "PyPy": + raise RuntimeError( + "uWSGI caching does not work under PyPy, see " + "the docs for more details." + ) try: import uwsgi + self._uwsgi = uwsgi except ImportError: - raise RuntimeError("uWSGI could not be imported, are you " - "running under uWSGI?") + raise RuntimeError( + "uWSGI could not be imported, are you running under uWSGI?" + ) self.cache = cache @@ -906,14 +917,14 @@ class UWSGICache(BaseCache): return self._uwsgi.cache_del(key, self.cache) def set(self, key, value, timeout=None): - return self._uwsgi.cache_update(key, pickle.dumps(value), - self._normalize_timeout(timeout), - self.cache) + return self._uwsgi.cache_update( + key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + ) def add(self, key, value, timeout=None): - return self._uwsgi.cache_set(key, pickle.dumps(value), - self._normalize_timeout(timeout), - self.cache) + return self._uwsgi.cache_set( + key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + ) def clear(self): return self._uwsgi.cache_clear(self.cache) diff --git a/src/werkzeug/contrib/fixers.py b/src/werkzeug/contrib/fixers.py index ad386f57..8df0afda 100644 --- a/src/werkzeug/contrib/fixers.py +++ b/src/werkzeug/contrib/fixers.py @@ -28,16 +28,18 @@ This module includes various helpers that fix web server behavior. """ import warnings +from ..datastructures import Headers +from ..datastructures import ResponseCacheControl +from ..http import parse_cache_control_header +from ..http import parse_options_header +from ..http import parse_set_header +from ..middleware.proxy_fix import ProxyFix as _ProxyFix +from ..useragents import UserAgent + try: - from urllib import unquote -except ImportError: from urllib.parse import unquote - -from werkzeug.http import parse_options_header, parse_cache_control_header, \ - parse_set_header -from werkzeug.useragents import UserAgent -from werkzeug.datastructures import Headers, ResponseCacheControl -from werkzeug.middleware.proxy_fix import ProxyFix as _ProxyFix +except ImportError: + from urllib import unquote class CGIRootFix(object): @@ -57,7 +59,7 @@ class CGIRootFix(object): ``LighttpdCGIRootFix``. """ - def __init__(self, app, app_root='/'): + def __init__(self, app, app_root="/"): warnings.warn( "'CGIRootFix' is deprecated as of version 0.15 and will be" " removed in version 1.0.", @@ -68,7 +70,7 @@ class CGIRootFix(object): self.app_root = app_root.strip("/") def __call__(self, environ, start_response): - environ['SCRIPT_NAME'] = self.app_root + environ["SCRIPT_NAME"] = self.app_root return self.app(environ, start_response) @@ -84,7 +86,6 @@ class LighttpdCGIRootFix(CGIRootFix): class PathInfoFromRequestUriFix(object): - """On windows environment variables are limited to the system charset which makes it impossible to store the `PATH_INFO` variable in the environment without loss of information on some systems. @@ -112,14 +113,13 @@ class PathInfoFromRequestUriFix(object): self.app = app def __call__(self, environ, start_response): - for key in 'REQUEST_URL', 'REQUEST_URI', 'UNENCODED_URL': + for key in "REQUEST_URL", "REQUEST_URI", "UNENCODED_URL": if key not in environ: continue request_uri = unquote(environ[key]) - script_name = unquote(environ.get('SCRIPT_NAME', '')) + script_name = unquote(environ.get("SCRIPT_NAME", "")) if request_uri.startswith(script_name): - environ['PATH_INFO'] = request_uri[len(script_name):] \ - .split('?', 1)[0] + environ["PATH_INFO"] = request_uri[len(script_name) :].split("?", 1)[0] break return self.app(environ, start_response) @@ -144,7 +144,6 @@ class ProxyFix(_ProxyFix): class HeaderRewriterFix(object): - """This middleware can remove response headers and add others. This is for example useful to remove the `Date` header from responses if you are using a server that adds that header, no matter if it's present or @@ -182,11 +181,11 @@ class HeaderRewriterFix(object): new_headers.append((key, value)) new_headers += self.add_headers return start_response(status, new_headers, exc_info) + return self.app(environ, rewriting_start_response) class InternetExplorerFix(object): - """This middleware fixes a couple of bugs with Microsoft Internet Explorer. Currently the following fixes are applied: @@ -224,40 +223,40 @@ class InternetExplorerFix(object): def fix_headers(self, environ, headers, status=None): if self.fix_vary: - header = headers.get('content-type', '') + header = headers.get("content-type", "") mimetype, options = parse_options_header(header) - if mimetype not in ('text/html', 'text/plain', 'text/sgml'): - headers.pop('vary', None) + if mimetype not in ("text/html", "text/plain", "text/sgml"): + headers.pop("vary", None) - if self.fix_attach and 'content-disposition' in headers: - pragma = parse_set_header(headers.get('pragma', '')) - pragma.discard('no-cache') + if self.fix_attach and "content-disposition" in headers: + pragma = parse_set_header(headers.get("pragma", "")) + pragma.discard("no-cache") header = pragma.to_header() if not header: - headers.pop('pragma', '') + headers.pop("pragma", "") else: - headers['Pragma'] = header - header = headers.get('cache-control', '') + headers["Pragma"] = header + header = headers.get("cache-control", "") if header: - cc = parse_cache_control_header(header, - cls=ResponseCacheControl) + cc = parse_cache_control_header(header, cls=ResponseCacheControl) cc.no_cache = None cc.no_store = False header = cc.to_header() if not header: - headers.pop('cache-control', '') + headers.pop("cache-control", "") else: - headers['Cache-Control'] = header + headers["Cache-Control"] = header def run_fixed(self, environ, start_response): def fixing_start_response(status, headers, exc_info=None): headers = Headers(headers) self.fix_headers(environ, headers, status) return start_response(status, headers.to_wsgi_list(), exc_info) + return self.app(environ, fixing_start_response) def __call__(self, environ, start_response): ua = UserAgent(environ) - if ua.browser != 'msie': + if ua.browser != "msie": return self.app(environ, start_response) return self.run_fixed(environ, start_response) diff --git a/src/werkzeug/contrib/iterio.py b/src/werkzeug/contrib/iterio.py index 41e1e20e..b6724540 100644 --- a/src/werkzeug/contrib/iterio.py +++ b/src/werkzeug/contrib/iterio.py @@ -41,13 +41,13 @@ r""" """ import warnings +from .._compat import implements_iterator + try: import greenlet except ImportError: greenlet = None -from werkzeug._compat import implements_iterator - warnings.warn( "'werkzeug.contrib.iterio' is deprecated as of version 0.15 and" " will be removed in version 1.0.", @@ -61,19 +61,18 @@ def _mixed_join(iterable, sentinel): iterator = iter(iterable) first_item = next(iterator, sentinel) if isinstance(first_item, bytes): - return first_item + b''.join(iterator) - return first_item + u''.join(iterator) + return first_item + b"".join(iterator) + return first_item + u"".join(iterator) def _newline(reference_string): if isinstance(reference_string, bytes): - return b'\n' - return u'\n' + return b"\n" + return u"\n" @implements_iterator class IterIO(object): - """Instances of this object implement an interface compatible with the standard Python :class:`file` object. Streams are either read-only or write-only depending on how the object is created. @@ -100,7 +99,7 @@ class IterIO(object): `sentinel` parameter was added. """ - def __new__(cls, obj, sentinel=''): + def __new__(cls, obj, sentinel=""): try: iterator = iter(obj) except TypeError: @@ -112,53 +111,53 @@ class IterIO(object): def tell(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return self.pos def isatty(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return False def seek(self, pos, mode=0): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def truncate(self, size=None): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def write(self, s): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def writelines(self, list): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def read(self, n=-1): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def readlines(self, sizehint=0): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def readline(self, length=None): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def __next__(self): if self.closed: @@ -170,12 +169,11 @@ class IterIO(object): class IterI(IterIO): - """Convert an stream into an iterator.""" - def __new__(cls, func, sentinel=''): + def __new__(cls, func, sentinel=""): if greenlet is None: - raise RuntimeError('IterI requires greenlet support') + raise RuntimeError("IterI requires greenlet support") stream = object.__new__(cls) stream._parent = greenlet.getcurrent() stream._buffer = [] @@ -201,7 +199,7 @@ class IterI(IterIO): def write(self, s): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if s: self.pos += len(s) self._buffer.append(s) @@ -212,7 +210,7 @@ class IterI(IterIO): def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") self._flush_impl() def _flush_impl(self): @@ -225,10 +223,9 @@ class IterI(IterIO): class IterO(IterIO): - """Iter output. Wrap an iterator and give it a stream like interface.""" - def __new__(cls, gen, sentinel=''): + def __new__(cls, gen, sentinel=""): self = object.__new__(cls) self._gen = gen self._buf = None @@ -241,8 +238,8 @@ class IterO(IterIO): return self def _buf_append(self, string): - '''Replace string directly without appending to an empty string, - avoiding type issues.''' + """Replace string directly without appending to an empty string, + avoiding type issues.""" if not self._buf: self._buf = string else: @@ -251,12 +248,12 @@ class IterO(IterIO): def close(self): if not self.closed: self.closed = True - if hasattr(self._gen, 'close'): + if hasattr(self._gen, "close"): self._gen.close() def seek(self, pos, mode=0): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if mode == 1: pos += self.pos elif mode == 2: @@ -264,10 +261,10 @@ class IterO(IterIO): self.pos = min(self.pos, self.pos + pos) return elif mode != 0: - raise IOError('Invalid argument') + raise IOError("Invalid argument") buf = [] try: - tmp_end_pos = len(self._buf or '') + tmp_end_pos = len(self._buf or "") while pos > tmp_end_pos: item = next(self._gen) tmp_end_pos += len(item) @@ -280,10 +277,10 @@ class IterO(IterIO): def read(self, n=-1): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if n < 0: self._buf_append(_mixed_join(self._gen, self.sentinel)) - result = self._buf[self.pos:] + result = self._buf[self.pos :] self.pos += len(result) return result new_pos = self.pos + n @@ -304,13 +301,13 @@ class IterO(IterIO): new_pos = max(0, new_pos) try: - return self._buf[self.pos:new_pos] + return self._buf[self.pos : new_pos] finally: self.pos = min(new_pos, len(self._buf)) def readline(self, length=None): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") nl_pos = -1 if self._buf: @@ -344,7 +341,7 @@ class IterO(IterIO): if length is not None and self.pos + length < new_pos: new_pos = self.pos + length try: - return self._buf[self.pos:new_pos] + return self._buf[self.pos : new_pos] finally: self.pos = min(new_pos, len(self._buf)) diff --git a/src/werkzeug/contrib/securecookie.py b/src/werkzeug/contrib/securecookie.py index b9ba20ed..c4c9eee2 100644 --- a/src/werkzeug/contrib/securecookie.py +++ b/src/werkzeug/contrib/securecookie.py @@ -88,19 +88,22 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pickle import base64 +import pickle import warnings +from hashlib import sha1 as _default_hash from hmac import new as hmac from time import time -from hashlib import sha1 as _default_hash -from werkzeug._compat import iteritems, text_type, to_bytes -from werkzeug.urls import url_quote_plus, url_unquote_plus -from werkzeug._internal import _date_to_unix -from werkzeug.contrib.sessions import ModificationTrackingDict -from werkzeug.security import safe_str_cmp -from werkzeug._compat import to_native +from .._compat import iteritems +from .._compat import text_type +from .._compat import to_bytes +from .._compat import to_native +from .._internal import _date_to_unix +from ..contrib.sessions import ModificationTrackingDict +from ..security import safe_str_cmp +from ..urls import url_quote_plus +from ..urls import url_unquote_plus warnings.warn( "'werkzeug.contrib.securecookie' is deprecated as of version 0.15" @@ -112,12 +115,10 @@ warnings.warn( class UnquoteError(Exception): - """Internal exception used to signal failures on quoting.""" class SecureCookie(ModificationTrackingDict): - """Represents a secure cookie. You can subclass this class and provide an alternative mac method. The import thing is that the mac method is a function with a similar interface to the hashlib. Required @@ -164,7 +165,7 @@ class SecureCookie(ModificationTrackingDict): # explicitly convert it into a bytestring because python 2.6 # no longer performs an implicit string conversion on hmac if secret_key is not None: - secret_key = to_bytes(secret_key, 'utf-8') + secret_key = to_bytes(secret_key, "utf-8") self.secret_key = secret_key self.new = new @@ -178,10 +179,10 @@ class SecureCookie(ModificationTrackingDict): ) def __repr__(self): - return '<%s %s%s>' % ( + return "<%s %s%s>" % ( self.__class__.__name__, dict.__repr__(self), - '*' if self.should_save else '' + "*" if self.should_save else "", ) @property @@ -201,7 +202,7 @@ class SecureCookie(ModificationTrackingDict): if cls.serialization_method is not None: value = cls.serialization_method.dumps(value) if cls.quote_base64: - value = b''.join( + value = b"".join( base64.b64encode(to_bytes(value, "utf8")).splitlines() ).strip() return value @@ -236,21 +237,19 @@ class SecureCookie(ModificationTrackingDict): :class:`datetime.datetime` object) """ if self.secret_key is None: - raise RuntimeError('no secret key defined') + raise RuntimeError("no secret key defined") if expires: - self['_expires'] = _date_to_unix(expires) + self["_expires"] = _date_to_unix(expires) result = [] mac = hmac(self.secret_key, None, self.hash_method) for key, value in sorted(self.items()): - result.append(('%s=%s' % ( - url_quote_plus(key), - self.quote(value).decode('ascii') - )).encode('ascii')) - mac.update(b'|' + result[-1]) - return b'?'.join([ - base64.b64encode(mac.digest()).strip(), - b'&'.join(result) - ]) + result.append( + ( + "%s=%s" % (url_quote_plus(key), self.quote(value).decode("ascii")) + ).encode("ascii") + ) + mac.update(b"|" + result[-1]) + return b"?".join([base64.b64encode(mac.digest()).strip(), b"&".join(result)]) @classmethod def unserialize(cls, string, secret_key): @@ -261,24 +260,24 @@ class SecureCookie(ModificationTrackingDict): :return: a new :class:`SecureCookie`. """ if isinstance(string, text_type): - string = string.encode('utf-8', 'replace') + string = string.encode("utf-8", "replace") if isinstance(secret_key, text_type): - secret_key = secret_key.encode('utf-8', 'replace') + secret_key = secret_key.encode("utf-8", "replace") try: - base64_hash, data = string.split(b'?', 1) + base64_hash, data = string.split(b"?", 1) except (ValueError, IndexError): items = () else: items = {} mac = hmac(secret_key, None, cls.hash_method) - for item in data.split(b'&'): - mac.update(b'|' + item) - if b'=' not in item: + for item in data.split(b"&"): + mac.update(b"|" + item) + if b"=" not in item: items = None break - key, value = item.split(b'=', 1) + key, value = item.split(b"=", 1) # try to make the key a string - key = url_unquote_plus(key.decode('ascii')) + key = url_unquote_plus(key.decode("ascii")) try: key = to_native(key) except UnicodeError: @@ -298,17 +297,17 @@ class SecureCookie(ModificationTrackingDict): except UnquoteError: items = () else: - if '_expires' in items: - if time() > items['_expires']: + if "_expires" in items: + if time() > items["_expires"]: items = () else: - del items['_expires'] + del items["_expires"] else: items = () return cls(items, secret_key, False) @classmethod - def load_cookie(cls, request, key='session', secret_key=None): + def load_cookie(cls, request, key="session", secret_key=None): """Loads a :class:`SecureCookie` from a cookie in request. If the cookie is not set, a new :class:`SecureCookie` instanced is returned. @@ -325,9 +324,19 @@ class SecureCookie(ModificationTrackingDict): return cls(secret_key=secret_key) return cls.unserialize(data, secret_key) - def save_cookie(self, response, key='session', expires=None, - session_expires=None, max_age=None, path='/', domain=None, - secure=None, httponly=False, force=False): + def save_cookie( + self, + response, + key="session", + expires=None, + session_expires=None, + max_age=None, + path="/", + domain=None, + secure=None, + httponly=False, + force=False, + ): """Saves the SecureCookie in a cookie on response object. All parameters that are not described here are forwarded directly to :meth:`~BaseResponse.set_cookie`. @@ -341,6 +350,13 @@ class SecureCookie(ModificationTrackingDict): """ if force or self.should_save: data = self.serialize(session_expires or expires) - response.set_cookie(key, data, expires=expires, max_age=max_age, - path=path, domain=domain, secure=secure, - httponly=httponly) + response.set_cookie( + key, + data, + expires=expires, + max_age=max_age, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + ) diff --git a/src/werkzeug/contrib/sessions.py b/src/werkzeug/contrib/sessions.py index 302fd11c..866e827c 100644 --- a/src/werkzeug/contrib/sessions.py +++ b/src/werkzeug/contrib/sessions.py @@ -51,22 +51,26 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import os +import re import tempfile import warnings +from hashlib import sha1 from os import path -from time import time +from pickle import dump +from pickle import HIGHEST_PROTOCOL +from pickle import load from random import random -from hashlib import sha1 -from pickle import dump, load, HIGHEST_PROTOCOL +from time import time -from werkzeug.datastructures import CallbackDict -from werkzeug.utils import dump_cookie, parse_cookie -from werkzeug.wsgi import ClosingIterator -from werkzeug.posixemulation import rename -from werkzeug._compat import PY2, text_type -from werkzeug.filesystem import get_filesystem_encoding +from .._compat import PY2 +from .._compat import text_type +from ..datastructures import CallbackDict +from ..filesystem import get_filesystem_encoding +from ..posixemulation import rename +from ..utils import dump_cookie +from ..utils import parse_cookie +from ..wsgi import ClosingIterator warnings.warn( "'werkzeug.contrib.sessions' is deprecated as of version 0.15 and" @@ -76,31 +80,28 @@ warnings.warn( stacklevel=2, ) -_sha1_re = re.compile(r'^[a-f0-9]{40}$') +_sha1_re = re.compile(r"^[a-f0-9]{40}$") def _urandom(): - if hasattr(os, 'urandom'): + if hasattr(os, "urandom"): return os.urandom(30) - return text_type(random()).encode('ascii') + return text_type(random()).encode("ascii") def generate_key(salt=None): if salt is None: - salt = repr(salt).encode('ascii') - return sha1(b''.join([ - salt, - str(time()).encode('ascii'), - _urandom() - ])).hexdigest() + salt = repr(salt).encode("ascii") + return sha1(b"".join([salt, str(time()).encode("ascii"), _urandom()])).hexdigest() class ModificationTrackingDict(CallbackDict): - __slots__ = ('modified',) + __slots__ = ("modified",) def __init__(self, *args, **kwargs): def on_update(self): self.modified = True + self.modified = False CallbackDict.__init__(self, on_update=on_update) dict.update(self, *args, **kwargs) @@ -120,12 +121,12 @@ class ModificationTrackingDict(CallbackDict): class Session(ModificationTrackingDict): - """Subclass of a dict that keeps track of direct object changes. Changes in mutable structures are not tracked, for those you have to set `modified` to `True` by hand. """ - __slots__ = ModificationTrackingDict.__slots__ + ('sid', 'new') + + __slots__ = ModificationTrackingDict.__slots__ + ("sid", "new") def __init__(self, data, sid, new=False): ModificationTrackingDict.__init__(self, data) @@ -133,10 +134,10 @@ class Session(ModificationTrackingDict): self.new = new def __repr__(self): - return '<%s %s%s>' % ( + return "<%s %s%s>" % ( self.__class__.__name__, dict.__repr__(self), - '*' if self.should_save else '' + "*" if self.should_save else "", ) @property @@ -151,7 +152,6 @@ class Session(ModificationTrackingDict): class SessionStore(object): - """Baseclass for all session stores. The Werkzeug contrib module does not implement any useful stores besides the filesystem store, application developers are encouraged to create their own stores. @@ -197,11 +197,10 @@ class SessionStore(object): #: used for temporary files by the filesystem session store -_fs_transaction_suffix = '.__wz_sess' +_fs_transaction_suffix = ".__wz_sess" class FilesystemSessionStore(SessionStore): - """Simple example session store that saves sessions on the filesystem. This store works best on POSIX systems and Windows Vista / Windows Server 2008 and newer. @@ -223,17 +222,23 @@ class FilesystemSessionStore(SessionStore): not yet saved. """ - def __init__(self, path=None, filename_template='werkzeug_%s.sess', - session_class=None, renew_missing=False, mode=0o644): + def __init__( + self, + path=None, + filename_template="werkzeug_%s.sess", + session_class=None, + renew_missing=False, + mode=0o644, + ): SessionStore.__init__(self, session_class) if path is None: path = tempfile.gettempdir() self.path = path if isinstance(filename_template, text_type) and PY2: - filename_template = filename_template.encode( - get_filesystem_encoding()) - assert not filename_template.endswith(_fs_transaction_suffix), \ - 'filename templates may not end with %s' % _fs_transaction_suffix + filename_template = filename_template.encode(get_filesystem_encoding()) + assert not filename_template.endswith(_fs_transaction_suffix), ( + "filename templates may not end with %s" % _fs_transaction_suffix + ) self.filename_template = filename_template self.renew_missing = renew_missing self.mode = mode @@ -248,9 +253,8 @@ class FilesystemSessionStore(SessionStore): def save(self, session): fn = self.get_session_filename(session.sid) - fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix, - dir=self.path) - f = os.fdopen(fd, 'wb') + fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix, dir=self.path) + f = os.fdopen(fd, "wb") try: dump(dict(session), f, HIGHEST_PROTOCOL) finally: @@ -272,7 +276,7 @@ class FilesystemSessionStore(SessionStore): if not self.is_valid_key(sid): return self.new() try: - f = open(self.get_session_filename(sid), 'rb') + f = open(self.get_session_filename(sid), "rb") except IOError: if self.renew_missing: return self.new() @@ -292,9 +296,10 @@ class FilesystemSessionStore(SessionStore): .. versionadded:: 0.6 """ - before, after = self.filename_template.split('%s', 1) - filename_re = re.compile(r'%s(.{5,})%s$' % (re.escape(before), - re.escape(after))) + before, after = self.filename_template.split("%s", 1) + filename_re = re.compile( + r"%s(.{5,})%s$" % (re.escape(before), re.escape(after)) + ) result = [] for filename in os.listdir(self.path): #: this is a session that is still being saved. @@ -307,7 +312,6 @@ class FilesystemSessionStore(SessionStore): class SessionMiddleware(object): - """A simple middleware that puts the session object of a store provided into the WSGI environ. It automatically sets cookies and restores sessions. @@ -323,11 +327,20 @@ class SessionMiddleware(object): compatibility. """ - def __init__(self, app, store, cookie_name='session_id', - cookie_age=None, cookie_expires=None, cookie_path='/', - cookie_domain=None, cookie_secure=None, - cookie_httponly=False, cookie_samesite='Lax', - environ_key='werkzeug.session'): + def __init__( + self, + app, + store, + cookie_name="session_id", + cookie_age=None, + cookie_expires=None, + cookie_path="/", + cookie_domain=None, + cookie_secure=None, + cookie_httponly=False, + cookie_samesite="Lax", + environ_key="werkzeug.session", + ): self.app = app self.store = store self.cookie_name = cookie_name @@ -341,7 +354,7 @@ class SessionMiddleware(object): self.environ_key = environ_key def __call__(self, environ, start_response): - cookie = parse_cookie(environ.get('HTTP_COOKIE', '')) + cookie = parse_cookie(environ.get("HTTP_COOKIE", "")) sid = cookie.get(self.cookie_name, None) if sid is None: session = self.store.new() @@ -352,12 +365,25 @@ class SessionMiddleware(object): def injecting_start_response(status, headers, exc_info=None): if session.should_save: self.store.save(session) - headers.append(('Set-Cookie', dump_cookie(self.cookie_name, - session.sid, self.cookie_age, - self.cookie_expires, self.cookie_path, - self.cookie_domain, self.cookie_secure, - self.cookie_httponly, - samesite=self.cookie_samesite))) + headers.append( + ( + "Set-Cookie", + dump_cookie( + self.cookie_name, + session.sid, + self.cookie_age, + self.cookie_expires, + self.cookie_path, + self.cookie_domain, + self.cookie_secure, + self.cookie_httponly, + samesite=self.cookie_samesite, + ), + ) + ) return start_response(status, headers, exc_info) - return ClosingIterator(self.app(environ, injecting_start_response), - lambda: self.store.save_if_modified(session)) + + return ClosingIterator( + self.app(environ, injecting_start_response), + lambda: self.store.save_if_modified(session), + ) diff --git a/src/werkzeug/contrib/wrappers.py b/src/werkzeug/contrib/wrappers.py index fa3bc565..49b82a71 100644 --- a/src/werkzeug/contrib/wrappers.py +++ b/src/werkzeug/contrib/wrappers.py @@ -23,11 +23,12 @@ import codecs import warnings -from werkzeug.exceptions import BadRequest -from werkzeug.utils import cached_property -from werkzeug.http import dump_options_header, parse_options_header -from werkzeug._compat import wsgi_decoding_dance -from werkzeug.wrappers.json import JSONMixin as _JSONMixin +from .._compat import wsgi_decoding_dance +from ..exceptions import BadRequest +from ..http import dump_options_header +from ..http import parse_options_header +from ..utils import cached_property +from ..wrappers.json import JSONMixin as _JSONMixin def is_known_charset(charset): @@ -87,8 +88,8 @@ class ProtobufRequestMixin(object): DeprecationWarning, stacklevel=2, ) - if 'protobuf' not in self.environ.get('CONTENT_TYPE', ''): - raise BadRequest('Not a Protobuf request') + if "protobuf" not in self.environ.get("CONTENT_TYPE", ""): + raise BadRequest("Not a Protobuf request") obj = proto_type() try: @@ -108,7 +109,8 @@ class RoutingArgsRequestMixin(object): """This request mixin adds support for the wsgiorg routing args `specification`_. - .. _specification: https://wsgi.readthedocs.io/en/latest/specifications/routing_args.html + .. _specification: https://wsgi.readthedocs.io/en/latest/ + specifications/routing_args.html .. deprecated:: 0.15 This mixin will be removed in version 1.0. @@ -122,7 +124,7 @@ class RoutingArgsRequestMixin(object): DeprecationWarning, stacklevel=2, ) - return self.environ.get('wsgiorg.routing_args', (()))[0] + return self.environ.get("wsgiorg.routing_args", (()))[0] def _set_routing_args(self, value): warnings.warn( @@ -133,13 +135,19 @@ class RoutingArgsRequestMixin(object): stacklevel=2, ) if self.shallow: - raise RuntimeError('A shallow request tried to modify the WSGI ' - 'environment. If you really want to do that, ' - 'set `shallow` to False.') - self.environ['wsgiorg.routing_args'] = (value, self.routing_vars) - - routing_args = property(_get_routing_args, _set_routing_args, doc=''' - The positional URL arguments as `tuple`.''') + raise RuntimeError( + "A shallow request tried to modify the WSGI " + "environment. If you really want to do that, " + "set `shallow` to False." + ) + self.environ["wsgiorg.routing_args"] = (value, self.routing_vars) + + routing_args = property( + _get_routing_args, + _set_routing_args, + doc=""" + The positional URL arguments as `tuple`.""", + ) del _get_routing_args, _set_routing_args def _get_routing_vars(self): @@ -150,7 +158,7 @@ class RoutingArgsRequestMixin(object): DeprecationWarning, stacklevel=2, ) - rv = self.environ.get('wsgiorg.routing_args') + rv = self.environ.get("wsgiorg.routing_args") if rv is not None: return rv[1] rv = {} @@ -167,13 +175,19 @@ class RoutingArgsRequestMixin(object): stacklevel=2, ) if self.shallow: - raise RuntimeError('A shallow request tried to modify the WSGI ' - 'environment. If you really want to do that, ' - 'set `shallow` to False.') - self.environ['wsgiorg.routing_args'] = (self.routing_args, value) - - routing_vars = property(_get_routing_vars, _set_routing_vars, doc=''' - The keyword URL arguments as `dict`.''') + raise RuntimeError( + "A shallow request tried to modify the WSGI " + "environment. If you really want to do that, " + "set `shallow` to False." + ) + self.environ["wsgiorg.routing_args"] = (self.routing_args, value) + + routing_vars = property( + _get_routing_vars, + _set_routing_vars, + doc=""" + The keyword URL arguments as `dict`.""", + ) del _get_routing_vars, _set_routing_vars @@ -216,9 +230,10 @@ class ReverseSlashBehaviorRequestMixin(object): DeprecationWarning, stacklevel=2, ) - path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '', - self.charset, self.encoding_errors) - return path.lstrip('/') + path = wsgi_decoding_dance( + self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors + ) + return path.lstrip("/") @cached_property def script_root(self): @@ -230,9 +245,10 @@ class ReverseSlashBehaviorRequestMixin(object): DeprecationWarning, stacklevel=2, ) - path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '', - self.charset, self.encoding_errors) - return path.rstrip('/') + '/' + path = wsgi_decoding_dance( + self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors + ) + return path.rstrip("/") + "/" class DynamicCharsetRequestMixin(object): @@ -268,7 +284,7 @@ class DynamicCharsetRequestMixin(object): #: is latin1 which is what HTTP specifies as default charset. #: You may however want to set this to utf-8 to better support #: browsers that do not transmit a charset for incoming data. - default_charset = 'latin1' + default_charset = "latin1" def unknown_charset(self, charset): """Called if a charset was provided but is not supported by @@ -279,7 +295,7 @@ class DynamicCharsetRequestMixin(object): :param charset: the charset that was not found. :return: the replacement charset. """ - return 'latin1' + return "latin1" @cached_property def charset(self): @@ -291,10 +307,10 @@ class DynamicCharsetRequestMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.environ.get('CONTENT_TYPE') + header = self.environ.get("CONTENT_TYPE") if header: ct, options = parse_options_header(header) - charset = options.get('charset') + charset = options.get("charset") if charset: if is_known_charset(charset): return charset @@ -327,7 +343,7 @@ class DynamicCharsetResponseMixin(object): """ #: the default charset. - default_charset = 'utf-8' + default_charset = "utf-8" def _get_charset(self): warnings.warn( @@ -337,9 +353,9 @@ class DynamicCharsetResponseMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.headers.get('content-type') + header = self.headers.get("content-type") if header: - charset = parse_options_header(header)[1].get('charset') + charset = parse_options_header(header)[1].get("charset") if charset: return charset return self.default_charset @@ -352,15 +368,18 @@ class DynamicCharsetResponseMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.headers.get('content-type') + header = self.headers.get("content-type") ct, options = parse_options_header(header) if not ct: - raise TypeError('Cannot set charset if Content-Type ' - 'header is missing.') - options['charset'] = charset - self.headers['Content-Type'] = dump_options_header(ct, options) - - charset = property(_get_charset, _set_charset, doc=""" + raise TypeError("Cannot set charset if Content-Type header is missing.") + options["charset"] = charset + self.headers["Content-Type"] = dump_options_header(ct, options) + + charset = property( + _get_charset, + _set_charset, + doc=""" The charset for the response. It's stored inside the - Content-Type header as a parameter.""") + Content-Type header as a parameter.""", + ) del _get_charset, _set_charset diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py index a39397a2..9643db96 100644 --- a/src/werkzeug/datastructures.py +++ b/src/werkzeug/datastructures.py @@ -8,24 +8,32 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import codecs import mimetypes +import re from copy import deepcopy from itertools import repeat -from werkzeug._internal import _missing -from werkzeug._compat import BytesIO, collections_abc, iterkeys, itervalues, \ - iteritems, iterlists, PY2, text_type, integer_types, string_types, \ - make_literal_wrapper, to_native -from werkzeug.filesystem import get_filesystem_encoding - +from ._compat import BytesIO +from ._compat import collections_abc +from ._compat import integer_types +from ._compat import iteritems +from ._compat import iterkeys +from ._compat import iterlists +from ._compat import itervalues +from ._compat import make_literal_wrapper +from ._compat import PY2 +from ._compat import string_types +from ._compat import text_type +from ._compat import to_native +from ._internal import _missing +from .filesystem import get_filesystem_encoding -_locale_delim_re = re.compile(r'[_-]') +_locale_delim_re = re.compile(r"[_-]") def is_immutable(self): - raise TypeError('%r objects are immutable' % self.__class__.__name__) + raise TypeError("%r objects are immutable" % self.__class__.__name__) def iter_multi_items(mapping): @@ -52,18 +60,28 @@ def native_itermethods(names): return lambda x: x def setviewmethod(cls, name): - viewmethod_name = 'view%s' % name - viewmethod = lambda self, *a, **kw: ViewItems(self, name, 'view_%s' % name, *a, **kw) - viewmethod.__doc__ = \ - '"""`%s()` object providing a view on %s"""' % (viewmethod_name, name) + viewmethod_name = "view%s" % name + repr_name = "view_%s" % name + + def viewmethod(self, *a, **kw): + return ViewItems(self, name, repr_name, *a, **kw) + + viewmethod.__name__ = viewmethod_name + viewmethod.__doc__ = "`%s()` object providing a view on %s" % ( + viewmethod_name, + name, + ) setattr(cls, viewmethod_name, viewmethod) def setitermethod(cls, name): itermethod = getattr(cls, name) - setattr(cls, 'iter%s' % name, itermethod) - listmethod = lambda self, *a, **kw: list(itermethod(self, *a, **kw)) - listmethod.__doc__ = \ - 'Like :py:meth:`iter%s`, but returns a list.' % name + setattr(cls, "iter%s" % name, itermethod) + + def listmethod(self, *a, **kw): + return list(itermethod(self, *a, **kw)) + + listmethod.__name__ = name + listmethod.__doc__ = "Like :py:meth:`iter%s`, but returns a list." % name setattr(cls, name, listmethod) def wrap(cls): @@ -71,11 +89,11 @@ def native_itermethods(names): setitermethod(cls, name) setviewmethod(cls, name) return cls + return wrap class ImmutableListMixin(object): - """Makes a :class:`list` immutable. .. versionadded:: 0.5 @@ -99,6 +117,7 @@ class ImmutableListMixin(object): def __iadd__(self, other): is_immutable(self) + __imul__ = __iadd__ def __setitem__(self, key, value): @@ -106,6 +125,7 @@ class ImmutableListMixin(object): def append(self, item): is_immutable(self) + remove = append def extend(self, iterable): @@ -125,7 +145,6 @@ class ImmutableListMixin(object): class ImmutableList(ImmutableListMixin, list): - """An immutable :class:`list`. .. versionadded:: 0.5 @@ -134,20 +153,17 @@ class ImmutableList(ImmutableListMixin, list): """ def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, - list.__repr__(self), - ) + return "%s(%s)" % (self.__class__.__name__, list.__repr__(self)) class ImmutableDictMixin(object): - """Makes a :class:`dict` immutable. .. versionadded:: 0.5 :private: """ + _hash_cache = None @classmethod @@ -191,7 +207,6 @@ class ImmutableDictMixin(object): class ImmutableMultiDictMixin(ImmutableDictMixin): - """Makes a :class:`MultiDict` immutable. .. versionadded:: 0.5 @@ -222,7 +237,6 @@ class ImmutableMultiDictMixin(ImmutableDictMixin): class UpdateDictMixin(object): - """Makes dicts call `self.on_update` on modifications. .. versionadded:: 0.5 @@ -232,12 +246,13 @@ class UpdateDictMixin(object): on_update = None - def calls_update(name): + def calls_update(name): # noqa: B902 def oncall(self, *args, **kw): rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) if self.on_update is not None: self.on_update(self) return rv + oncall.__name__ = name return oncall @@ -258,16 +273,15 @@ class UpdateDictMixin(object): self.on_update(self) return rv - __setitem__ = calls_update('__setitem__') - __delitem__ = calls_update('__delitem__') - clear = calls_update('clear') - popitem = calls_update('popitem') - update = calls_update('update') + __setitem__ = calls_update("__setitem__") + __delitem__ = calls_update("__delitem__") + clear = calls_update("clear") + popitem = calls_update("popitem") + update = calls_update("update") del calls_update class TypeConversionDict(dict): - """Works like a regular dict but the :meth:`get` method can perform type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` are subclasses of this class and provide the same feature. @@ -309,7 +323,6 @@ class TypeConversionDict(dict): class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): - """Works like a :class:`TypeConversionDict` but does not support modifications. @@ -328,7 +341,6 @@ class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): class ViewItems(object): - def __init__(self, multi_dict, method, repr_name, *a, **kw): self.__multi_dict = multi_dict self.__method = method @@ -340,15 +352,14 @@ class ViewItems(object): return getattr(self.__multi_dict, self.__method)(*self.__a, **self.__kw) def __repr__(self): - return '%s(%r)' % (self.__repr_name, list(self.__get_items())) + return "%s(%r)" % (self.__repr_name, list(self.__get_items())) def __iter__(self): return iter(self.__get_items()) -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class MultiDict(TypeConversionDict): - """A :class:`MultiDict` is a dictionary subclass customized to deal with multiple values for the same key which is for example used by the parsing functions in the wrappers. This is necessary because some HTML form @@ -678,17 +689,17 @@ class MultiDict(TypeConversionDict): return self.deepcopy(memo=memo) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, list(iteritems(self, multi=True))) + return "%s(%r)" % (self.__class__.__name__, list(iteritems(self, multi=True))) class _omd_bucket(object): - """Wraps values in the :class:`OrderedMultiDict`. This makes it possible to keep an order over multiple different keys. It requires a lot of extra memory and slows down access a lot, but makes it possible to access elements in O(1) and iterate in O(n). """ - __slots__ = ('prev', 'key', 'value', 'next') + + __slots__ = ("prev", "key", "value", "next") def __init__(self, omd, key, value): self.prev = omd._last_bucket @@ -713,9 +724,8 @@ class _omd_bucket(object): omd._last_bucket = self.prev -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class OrderedMultiDict(MultiDict): - """Works like a regular :class:`MultiDict` but preserves the order of the fields. To convert the ordered multi dict into a list you can use the :meth:`items` method and pass it ``multi=True``. @@ -822,7 +832,7 @@ class OrderedMultiDict(MultiDict): ptr = ptr.next def listvalues(self): - for key, values in iterlists(self): + for _key, values in iterlists(self): yield values def add(self, key, value): @@ -849,8 +859,7 @@ class OrderedMultiDict(MultiDict): self.add(key, value) def setlistdefault(self, key, default_list=None): - raise TypeError('setlistdefault is unsupported for ' - 'ordered multi dicts') + raise TypeError("setlistdefault is unsupported for ordered multi dicts") def update(self, mapping): for key, value in iter_multi_items(mapping): @@ -893,21 +902,21 @@ class OrderedMultiDict(MultiDict): def _options_header_vkw(value, kw): - return dump_options_header(value, dict((k.replace('_', '-'), v) - for k, v in kw.items())) + return dump_options_header( + value, dict((k.replace("_", "-"), v) for k, v in kw.items()) + ) def _unicodify_header_value(value): if isinstance(value, bytes): - value = value.decode('latin-1') + value = value.decode("latin-1") if not isinstance(value, text_type): value = text_type(value) return value -@native_itermethods(['keys', 'values', 'items']) +@native_itermethods(["keys", "values", "items"]) class Headers(object): - """An object that stores some headers. It has a dict-like interface but is ordered and can store the same keys multiple times. @@ -968,8 +977,7 @@ class Headers(object): raise exceptions.BadRequestKeyError(key) def __eq__(self, other): - return other.__class__ is self.__class__ and \ - set(other._list) == set(self._list) + return other.__class__ is self.__class__ and set(other._list) == set(self._list) __hash__ = None @@ -1007,7 +1015,7 @@ class Headers(object): except KeyError: return default if as_bytes: - rv = rv.encode('latin1') + rv = rv.encode("latin1") if type is None: return rv try: @@ -1036,7 +1044,7 @@ class Headers(object): for k, v in self: if k.lower() == ikey: if as_bytes: - v = v.encode('latin1') + v = v.encode("latin1") if type is not None: try: v = type(v) @@ -1168,10 +1176,12 @@ class Headers(object): def _validate_value(self, value): if not isinstance(value, text_type): - raise TypeError('Value should be unicode.') - if u'\n' in value or u'\r' in value: - raise ValueError('Detected newline in header value. This is ' - 'a potential security problem') + raise TypeError("Value should be unicode.") + if u"\n" in value or u"\r" in value: + raise ValueError( + "Detected newline in header value. This is " + "a potential security problem" + ) def add_header(self, _key, _value, **_kw): """Add a new header tuple to the list. @@ -1210,7 +1220,7 @@ class Headers(object): return listiter = iter(self._list) ikey = _key.lower() - for idx, (old_key, old_value) in enumerate(listiter): + for idx, (old_key, _old_value) in enumerate(listiter): if old_key.lower() == ikey: # replace first ocurrence self._list[idx] = (_key, _value) @@ -1218,7 +1228,7 @@ class Headers(object): else: self._list.append((_key, _value)) return - self._list[idx + 1:] = [t for t in listiter if t[0].lower() != ikey] + self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] def setdefault(self, key, default): """Returns the value for the key if it is in the dict, otherwise it @@ -1250,17 +1260,18 @@ class Headers(object): else: self.set(key, value) - def to_list(self, charset='iso-8859-1'): + def to_list(self, charset="iso-8859-1"): """Convert the headers into a list suitable for WSGI. .. deprecated:: 0.9 """ from warnings import warn + warn( "'to_list' deprecated as of version 0.9 and will be removed" " in version 1.0. Use 'to_wsgi_list' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return self.to_wsgi_list() @@ -1273,7 +1284,7 @@ class Headers(object): :return: list """ if PY2: - return [(to_native(k), v.encode('latin1')) for k, v in self] + return [(to_native(k), v.encode("latin1")) for k, v in self] return list(self) def copy(self): @@ -1286,19 +1297,15 @@ class Headers(object): """Returns formatted headers suitable for HTTP transmission.""" strs = [] for key, value in self.to_wsgi_list(): - strs.append('%s: %s' % (key, value)) - strs.append('\r\n') - return '\r\n'.join(strs) + strs.append("%s: %s" % (key, value)) + strs.append("\r\n") + return "\r\n".join(strs) def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - list(self) - ) + return "%s(%r)" % (self.__class__.__name__, list(self)) class ImmutableHeadersMixin(object): - """Makes a :class:`Headers` immutable. We do not mark them as hashable though since the only usecase for this datastructure in Werkzeug is a view on a mutable structure. @@ -1313,10 +1320,12 @@ class ImmutableHeadersMixin(object): def __setitem__(self, key, value): is_immutable(self) + set = __setitem__ def add(self, item): is_immutable(self) + remove = add_header = add def extend(self, iterable): @@ -1336,7 +1345,6 @@ class ImmutableHeadersMixin(object): class EnvironHeaders(ImmutableHeadersMixin, Headers): - """Read only version of the headers from a WSGI environment. This provides the same interface as `Headers` and is constructed from a WSGI environment. @@ -1360,10 +1368,10 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): # used because get() calls it. if not isinstance(key, string_types): raise KeyError(key) - key = key.upper().replace('-', '_') - if key in ('CONTENT_TYPE', 'CONTENT_LENGTH'): + key = key.upper().replace("-", "_") + if key in ("CONTENT_TYPE", "CONTENT_LENGTH"): return _unicodify_header_value(self.environ[key]) - return _unicodify_header_value(self.environ['HTTP_' + key]) + return _unicodify_header_value(self.environ["HTTP_" + key]) def __len__(self): # the iter is necessary because otherwise list calls our @@ -1372,21 +1380,23 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): def __iter__(self): for key, value in iteritems(self.environ): - if key.startswith('HTTP_') and key not in \ - ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): - yield (key[5:].replace('_', '-').title(), - _unicodify_header_value(value)) - elif key in ('CONTENT_TYPE', 'CONTENT_LENGTH') and value: - yield (key.replace('_', '-').title(), - _unicodify_header_value(value)) + if key.startswith("HTTP_") and key not in ( + "HTTP_CONTENT_TYPE", + "HTTP_CONTENT_LENGTH", + ): + yield ( + key[5:].replace("_", "-").title(), + _unicodify_header_value(value), + ) + elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: + yield (key.replace("_", "-").title(), _unicodify_header_value(value)) def copy(self): - raise TypeError('cannot create %r copies' % self.__class__.__name__) + raise TypeError("cannot create %r copies" % self.__class__.__name__) -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): - """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` instances as sequence and it will combine the return values of all wrapped dicts: @@ -1417,8 +1427,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): @classmethod def fromkeys(cls): - raise TypeError('cannot create %r instances by fromkeys' % - cls.__name__) + raise TypeError("cannot create %r instances by fromkeys" % cls.__name__) def __getitem__(self, key): for d in self.dicts: @@ -1471,7 +1480,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): yield key, value def values(self): - for key, value in iteritems(self): + for _key, value in iteritems(self): yield value def lists(self): @@ -1523,11 +1532,10 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): has_key = __contains__ def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.dicts) + return "%s(%r)" % (self.__class__.__name__, self.dicts) class FileMultiDict(MultiDict): - """A special :class:`MultiDict` that has convenience methods to add files to it. This is used for :class:`EnvironBuilder` and generally useful for unittesting. @@ -1550,27 +1558,24 @@ class FileMultiDict(MultiDict): if isinstance(file, string_types): if filename is None: filename = file - file = open(file, 'rb') + file = open(file, "rb") if filename and content_type is None: - content_type = mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' + content_type = ( + mimetypes.guess_type(filename)[0] or "application/octet-stream" + ) value = FileStorage(file, filename, name, content_type) self.add(name, value) class ImmutableDict(ImmutableDictMixin, dict): - """An immutable :class:`dict`. .. versionadded:: 0.5 """ def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, - dict.__repr__(self), - ) + return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self)) def copy(self): """Return a shallow mutable copy of this object. Keep in mind that @@ -1584,7 +1589,6 @@ class ImmutableDict(ImmutableDictMixin, dict): class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): - """An immutable :class:`MultiDict`. .. versionadded:: 0.5 @@ -1602,7 +1606,6 @@ class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): - """An immutable :class:`OrderedMultiDict`. .. versionadded:: 0.6 @@ -1622,9 +1625,8 @@ class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): return self -@native_itermethods(['values']) +@native_itermethods(["values"]) class Accept(ImmutableList): - """An :class:`Accept` object is just a list subclass for lists of ``(value, quality)`` tuples. It is automatically sorted by specificity and quality. @@ -1663,17 +1665,20 @@ class Accept(ImmutableList): list.__init__(self, values) else: self.provided = True - values = sorted(values, key=lambda x: (self._specificity(x[0]), x[1], x[0]), - reverse=True) + values = sorted( + values, + key=lambda x: (self._specificity(x[0]), x[1], x[0]), + reverse=True, + ) list.__init__(self, values) def _specificity(self, value): """Returns a tuple describing the value's specificity.""" - return value != '*', + return (value != "*",) def _value_matches(self, value, item): """Check if a value matches a given accept item.""" - return item == '*' or item.lower() == value.lower() + return item == "*" or item.lower() == value.lower() def __getitem__(self, key): """Besides index lookup (getting item n) you can also pass it a string @@ -1697,15 +1702,15 @@ class Accept(ImmutableList): return 0 def __contains__(self, value): - for item, quality in self: + for item, _quality in self: if self._value_matches(value, item): return True return False def __repr__(self): - return '%s([%s])' % ( + return "%s([%s])" % ( self.__class__.__name__, - ', '.join('(%r, %s)' % (x, y) for x, y in self) + ", ".join("(%r, %s)" % (x, y) for x, y in self), ) def index(self, key): @@ -1718,7 +1723,7 @@ class Accept(ImmutableList): with the list API. """ if isinstance(key, string_types): - for idx, (item, quality) in enumerate(self): + for idx, (item, _quality) in enumerate(self): if self._value_matches(key, item): return idx raise ValueError(key) @@ -1744,9 +1749,9 @@ class Accept(ImmutableList): result = [] for value, quality in self: if quality != 1: - value = '%s;q=%s' % (value, quality) + value = "%s;q=%s" % (value, quality) result.append(value) - return ','.join(result) + return ",".join(result) def __str__(self): return self.to_header() @@ -1791,75 +1796,71 @@ class Accept(ImmutableList): class MIMEAccept(Accept): - """Like :class:`Accept` but with special methods and behavior for mimetypes. """ def _specificity(self, value): - return tuple(x != '*' for x in value.split('/', 1)) + return tuple(x != "*" for x in value.split("/", 1)) def _value_matches(self, value, item): def _normalize(x): x = x.lower() - return ('*', '*') if x == '*' else x.split('/', 1) + return ("*", "*") if x == "*" else x.split("/", 1) # this is from the application which is trusted. to avoid developer # frustration we actually check these for valid values - if '/' not in value: - raise ValueError('invalid mimetype %r' % value) + if "/" not in value: + raise ValueError("invalid mimetype %r" % value) value_type, value_subtype = _normalize(value) - if value_type == '*' and value_subtype != '*': - raise ValueError('invalid mimetype %r' % value) + if value_type == "*" and value_subtype != "*": + raise ValueError("invalid mimetype %r" % value) - if '/' not in item: + if "/" not in item: return False item_type, item_subtype = _normalize(item) - if item_type == '*' and item_subtype != '*': + if item_type == "*" and item_subtype != "*": return False return ( - (item_type == item_subtype == '*' - or value_type == value_subtype == '*') - or (item_type == value_type and (item_subtype == '*' - or value_subtype == '*' - or item_subtype == value_subtype)) + item_type == item_subtype == "*" or value_type == value_subtype == "*" + ) or ( + item_type == value_type + and ( + item_subtype == "*" + or value_subtype == "*" + or item_subtype == value_subtype + ) ) @property def accept_html(self): """True if this object accepts HTML.""" return ( - 'text/html' in self - or 'application/xhtml+xml' in self - or self.accept_xhtml + "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml ) @property def accept_xhtml(self): """True if this object accepts XHTML.""" - return ( - 'application/xhtml+xml' in self - or 'application/xml' in self - ) + return "application/xhtml+xml" in self or "application/xml" in self @property def accept_json(self): """True if this object accepts JSON.""" - return 'application/json' in self + return "application/json" in self class LanguageAccept(Accept): - """Like :class:`Accept` but with normalization for languages.""" def _value_matches(self, value, item): def _normalize(language): return _locale_delim_re.split(language.lower()) - return item == '*' or _normalize(value) == _normalize(item) + return item == "*" or _normalize(value) == _normalize(item) -class CharsetAccept(Accept): +class CharsetAccept(Accept): """Like :class:`Accept` but with normalization for charsets.""" def _value_matches(self, value, item): @@ -1868,20 +1869,22 @@ class CharsetAccept(Accept): return codecs.lookup(name).name except LookupError: return name.lower() - return item == '*' or _normalize(value) == _normalize(item) + + return item == "*" or _normalize(value) == _normalize(item) def cache_property(key, empty, type): """Return a new property object for a cache header. Useful if you want to add support for a cache extension in a subclass.""" - return property(lambda x: x._get_cache_value(key, empty, type), - lambda x, v: x._set_cache_value(key, v, type), - lambda x: x._del_cache_value(key), - 'accessor for %r' % key) + return property( + lambda x: x._get_cache_value(key, empty, type), + lambda x, v: x._set_cache_value(key, v, type), + lambda x: x._del_cache_value(key), + "accessor for %r" % key, + ) class _CacheControl(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Cache-Control header. It has accessors for all the cache-control directives specified in RFC 2616. The class does not differentiate between request and response directives. @@ -1913,10 +1916,10 @@ class _CacheControl(UpdateDictMixin, dict): no longer existing `CacheControl` class. """ - no_cache = cache_property('no-cache', '*', None) - no_store = cache_property('no-store', None, bool) - max_age = cache_property('max-age', -1, int) - no_transform = cache_property('no-transform', None, None) + no_cache = cache_property("no-cache", "*", None) + no_store = cache_property("no-store", None, bool) + max_age = cache_property("max-age", -1, int) + no_transform = cache_property("no-transform", None, None) def __init__(self, values=(), on_update=None): dict.__init__(self, values or ()) @@ -1966,16 +1969,13 @@ class _CacheControl(UpdateDictMixin, dict): return self.to_header() def __repr__(self): - return '<%s %s>' % ( + return "<%s %s>" % ( self.__class__.__name__, - " ".join( - "%s=%r" % (k, v) for k, v in sorted(self.items()) - ), + " ".join("%s=%r" % (k, v) for k, v in sorted(self.items())), ) class RequestCacheControl(ImmutableDictMixin, _CacheControl): - """A cache control for requests. This is immutable and gives access to all the request-relevant cache control headers. @@ -1989,14 +1989,13 @@ class RequestCacheControl(ImmutableDictMixin, _CacheControl): both for request and response. """ - max_stale = cache_property('max-stale', '*', int) - min_fresh = cache_property('min-fresh', '*', int) - no_transform = cache_property('no-transform', None, None) - only_if_cached = cache_property('only-if-cached', None, bool) + max_stale = cache_property("max-stale", "*", int) + min_fresh = cache_property("min-fresh", "*", int) + no_transform = cache_property("no-transform", None, None) + only_if_cached = cache_property("only-if-cached", None, bool) class ResponseCacheControl(_CacheControl): - """A cache control for responses. Unlike :class:`RequestCacheControl` this is mutable and gives access to response-relevant cache control headers. @@ -2011,11 +2010,11 @@ class ResponseCacheControl(_CacheControl): both for request and response. """ - public = cache_property('public', None, bool) - private = cache_property('private', '*', None) - must_revalidate = cache_property('must-revalidate', None, bool) - proxy_revalidate = cache_property('proxy-revalidate', None, bool) - s_maxage = cache_property('s-maxage', None, None) + public = cache_property("public", None, bool) + private = cache_property("private", "*", None) + must_revalidate = cache_property("must-revalidate", None, bool) + proxy_revalidate = cache_property("proxy-revalidate", None, bool) + s_maxage = cache_property("s-maxage", None, None) # attach cache_property to the _CacheControl as staticmethod @@ -2024,7 +2023,6 @@ _CacheControl.cache_property = staticmethod(cache_property) class CallbackDict(UpdateDictMixin, dict): - """A dict that calls a function passed every time something is changed. The function is passed the dict instance. """ @@ -2034,14 +2032,10 @@ class CallbackDict(UpdateDictMixin, dict): self.on_update = on_update def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, - dict.__repr__(self) - ) + return "<%s %s>" % (self.__class__.__name__, dict.__repr__(self)) class HeaderSet(collections_abc.MutableSet): - """Similar to the :class:`ETags` class this implements a set-like structure. Unlike :class:`ETags` this is case insensitive and used for vary, allow, and content-language headers. @@ -2153,7 +2147,7 @@ class HeaderSet(collections_abc.MutableSet): def to_header(self): """Convert the header set into an HTTP header string.""" - return ', '.join(map(quote_header_value, self._headers)) + return ", ".join(map(quote_header_value, self._headers)) def __getitem__(self, idx): return self._headers[idx] @@ -2188,14 +2182,10 @@ class HeaderSet(collections_abc.MutableSet): return self.to_header() def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - self._headers - ) + return "%s(%r)" % (self.__class__.__name__, self._headers) class ETags(collections_abc.Container, collections_abc.Iterable): - """A set that can be used to check if one etag is present in a collection of etags. """ @@ -2245,15 +2235,14 @@ class ETags(collections_abc.Container, collections_abc.Iterable): def to_header(self): """Convert the etags set into a HTTP header string.""" if self.star_tag: - return '*' - return ', '.join( - ['"%s"' % x for x in self._strong] - + ['W/"%s"' % x for x in self._weak] + return "*" + return ", ".join( + ['"%s"' % x for x in self._strong] + ['W/"%s"' % x for x in self._weak] ) def __call__(self, etag=None, data=None, include_weak=False): if [etag, data].count(None) != 1: - raise TypeError('either tag or data required, but at least one') + raise TypeError("either tag or data required, but at least one") if etag is None: etag = generate_etag(data) if include_weak: @@ -2276,11 +2265,10 @@ class ETags(collections_abc.Container, collections_abc.Iterable): return self.contains(etag) def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class IfRange(object): - """Very simple object that represents the `If-Range` header in parsed form. It will either have neither a etag or date or one of either but never both. @@ -2301,13 +2289,13 @@ class IfRange(object): return http_date(self.date) if self.etag is not None: return quote_etag(self.etag) - return '' + return "" def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class Range(object): @@ -2339,7 +2327,7 @@ class Range(object): exactly one range and it is satisfiable it returns a ``(start, stop)`` tuple, otherwise `None`. """ - if self.units != 'bytes' or length is None or len(self.ranges) != 1: + if self.units != "bytes" or length is None or len(self.ranges) != 1: return None start, end = self.ranges[0] if end is None: @@ -2362,10 +2350,10 @@ class Range(object): ranges = [] for begin, end in self.ranges: if end is None: - ranges.append('%s-' % begin if begin >= 0 else str(begin)) + ranges.append("%s-" % begin if begin >= 0 else str(begin)) else: - ranges.append('%s-%s' % (begin, end - 1)) - return '%s=%s' % (self.units, ','.join(ranges)) + ranges.append("%s-%s" % (begin, end - 1)) + return "%s=%s" % (self.units, ",".join(ranges)) def to_content_range_header(self, length): """Converts the object into `Content-Range` HTTP header, @@ -2373,32 +2361,33 @@ class Range(object): """ range_for_length = self.range_for_length(length) if range_for_length is not None: - return '%s %d-%d/%d' % (self.units, - range_for_length[0], - range_for_length[1] - 1, length) + return "%s %d-%d/%d" % ( + self.units, + range_for_length[0], + range_for_length[1] - 1, + length, + ) return None def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class ContentRange(object): - """Represents the content range header. .. versionadded:: 0.7 """ def __init__(self, units, start, stop, length=None, on_update=None): - assert is_byte_range_valid(start, stop, length), \ - 'Bad range provided' + assert is_byte_range_valid(start, stop, length), "Bad range provided" self.on_update = on_update self.set(start, stop, length, units) - def _callback_property(name): + def _callback_property(name): # noqa: B902 def fget(self): return getattr(self, name) @@ -2406,22 +2395,23 @@ class ContentRange(object): setattr(self, name, value) if self.on_update is not None: self.on_update(self) + return property(fget, fset) #: The units to use, usually "bytes" - units = _callback_property('_units') + units = _callback_property("_units") #: The start point of the range or `None`. - start = _callback_property('_start') + start = _callback_property("_start") #: The stop point of the range (non-inclusive) or `None`. Can only be #: `None` if also start is `None`. - stop = _callback_property('_stop') + stop = _callback_property("_stop") #: The length of the range or `None`. - length = _callback_property('_length') + length = _callback_property("_length") + del _callback_property - def set(self, start, stop, length=None, units='bytes'): + def set(self, start, stop, length=None, units="bytes"): """Simple method to update the ranges.""" - assert is_byte_range_valid(start, stop, length), \ - 'Bad range provided' + assert is_byte_range_valid(start, stop, length), "Bad range provided" self._units = units self._start = start self._stop = stop @@ -2437,19 +2427,14 @@ class ContentRange(object): def to_header(self): if self.units is None: - return '' + return "" if self.length is None: - length = '*' + length = "*" else: length = self.length if self.start is None: - return '%s */%s' % (self.units, length) - return '%s %s-%s/%s' % ( - self.units, - self.start, - self.stop - 1, - length - ) + return "%s */%s" % (self.units, length) + return "%s %s-%s/%s" % (self.units, self.start, self.stop - 1, length) def __nonzero__(self): return self.units is not None @@ -2460,11 +2445,10 @@ class ContentRange(object): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class Authorization(ImmutableDictMixin, dict): - """Represents an `Authorization` header sent by the client. You should not create this kind of object yourself but use it when it's returned by the `parse_authorization_header` function. @@ -2481,77 +2465,107 @@ class Authorization(ImmutableDictMixin, dict): dict.__init__(self, data or {}) self.type = auth_type - username = property(lambda x: x.get('username'), doc=''' + username = property( + lambda self: self.get("username"), + doc=""" The username transmitted. This is set for both basic and digest - auth all the time.''') - password = property(lambda x: x.get('password'), doc=''' + auth all the time.""", + ) + password = property( + lambda self: self.get("password"), + doc=""" When the authentication type is basic this is the password - transmitted by the client, else `None`.''') - realm = property(lambda x: x.get('realm'), doc=''' - This is the server realm sent back for HTTP digest auth.''') - nonce = property(lambda x: x.get('nonce'), doc=''' + transmitted by the client, else `None`.""", + ) + realm = property( + lambda self: self.get("realm"), + doc=""" + This is the server realm sent back for HTTP digest auth.""", + ) + nonce = property( + lambda self: self.get("nonce"), + doc=""" The nonce the server sent for digest auth, sent back by the client. A nonce should be unique for every 401 response for HTTP digest - auth.''') - uri = property(lambda x: x.get('uri'), doc=''' + auth.""", + ) + uri = property( + lambda self: self.get("uri"), + doc=""" The URI from Request-URI of the Request-Line; duplicated because proxies are allowed to change the Request-Line in transit. HTTP - digest auth only.''') - nc = property(lambda x: x.get('nc'), doc=''' + digest auth only.""", + ) + nc = property( + lambda self: self.get("nc"), + doc=""" The nonce count value transmitted by clients if a qop-header is - also transmitted. HTTP digest auth only.''') - cnonce = property(lambda x: x.get('cnonce'), doc=''' + also transmitted. HTTP digest auth only.""", + ) + cnonce = property( + lambda self: self.get("cnonce"), + doc=""" If the server sent a qop-header in the ``WWW-Authenticate`` header, the client has to provide this value for HTTP digest auth. - See the RFC for more details.''') - response = property(lambda x: x.get('response'), doc=''' + See the RFC for more details.""", + ) + response = property( + lambda self: self.get("response"), + doc=""" A string of 32 hex digits computed as defined in RFC 2617, which - proves that the user knows a password. Digest auth only.''') - opaque = property(lambda x: x.get('opaque'), doc=''' + proves that the user knows a password. Digest auth only.""", + ) + opaque = property( + lambda self: self.get("opaque"), + doc=""" The opaque header from the server returned unchanged by the client. It is recommended that this string be base64 or hexadecimal data. - Digest auth only.''') - qop = property(lambda x: x.get('qop'), doc=''' + Digest auth only.""", + ) + qop = property( + lambda self: self.get("qop"), + doc=""" Indicates what "quality of protection" the client has applied to the message for HTTP digest auth. Note that this is a single token, - not a quoted list of alternatives as in WWW-Authenticate.''') + not a quoted list of alternatives as in WWW-Authenticate.""", + ) class WWWAuthenticate(UpdateDictMixin, dict): - """Provides simple access to `WWW-Authenticate` headers.""" #: list of keys that require quoting in the generated header - _require_quoting = frozenset(['domain', 'nonce', 'opaque', 'realm', 'qop']) + _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"]) def __init__(self, auth_type=None, values=None, on_update=None): dict.__init__(self, values or ()) if auth_type: - self['__auth_type__'] = auth_type + self["__auth_type__"] = auth_type self.on_update = on_update - def set_basic(self, realm='authentication required'): + def set_basic(self, realm="authentication required"): """Clear the auth info and enable basic auth.""" dict.clear(self) - dict.update(self, {'__auth_type__': 'basic', 'realm': realm}) + dict.update(self, {"__auth_type__": "basic", "realm": realm}) if self.on_update: self.on_update(self) - def set_digest(self, realm, nonce, qop=('auth',), opaque=None, - algorithm=None, stale=False): + def set_digest( + self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False + ): """Clear the auth info and enable digest auth.""" d = { - '__auth_type__': 'digest', - 'realm': realm, - 'nonce': nonce, - 'qop': dump_header(qop) + "__auth_type__": "digest", + "realm": realm, + "nonce": nonce, + "qop": dump_header(qop), } if stale: - d['stale'] = 'TRUE' + d["stale"] = "TRUE" if opaque is not None: - d['opaque'] = opaque + d["opaque"] = opaque if algorithm is not None: - d['algorithm'] = algorithm + d["algorithm"] = algorithm dict.clear(self) dict.update(self, d) if self.on_update: @@ -2560,23 +2574,30 @@ class WWWAuthenticate(UpdateDictMixin, dict): def to_header(self): """Convert the stored values into a WWW-Authenticate header.""" d = dict(self) - auth_type = d.pop('__auth_type__', None) or 'basic' - return '%s %s' % (auth_type.title(), ', '.join([ - '%s=%s' % (key, quote_header_value(value, - allow_token=key not in self._require_quoting)) - for key, value in iteritems(d) - ])) + auth_type = d.pop("__auth_type__", None) or "basic" + return "%s %s" % ( + auth_type.title(), + ", ".join( + [ + "%s=%s" + % ( + key, + quote_header_value( + value, allow_token=key not in self._require_quoting + ), + ) + for key, value in iteritems(d) + ] + ), + ) def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.to_header() - ) + return "<%s %r>" % (self.__class__.__name__, self.to_header()) - def auth_property(name, doc=None): + def auth_property(name, doc=None): # noqa: B902 """A static helper function for subclasses to add extra authentication system properties onto a class:: @@ -2592,70 +2613,90 @@ class WWWAuthenticate(UpdateDictMixin, dict): self.pop(name, None) else: self[name] = str(value) + return property(lambda x: x.get(name), _set_value, doc=doc) - def _set_property(name, doc=None): + def _set_property(name, doc=None): # noqa: B902 def fget(self): def on_update(header_set): if not header_set and name in self: del self[name] elif header_set: self[name] = header_set.to_header() + return parse_set_header(self.get(name), on_update) + return property(fget, doc=doc) - type = auth_property('__auth_type__', doc=''' - The type of the auth mechanism. HTTP currently specifies - `Basic` and `Digest`.''') - realm = auth_property('realm', doc=''' - A string to be displayed to users so they know which username and - password to use. This string should contain at least the name of - the host performing the authentication and might additionally - indicate the collection of users who might have access.''') - domain = _set_property('domain', doc=''' - A list of URIs that define the protection space. If a URI is an - absolute path, it is relative to the canonical root URL of the - server being accessed.''') - nonce = auth_property('nonce', doc=''' + type = auth_property( + "__auth_type__", + doc="""The type of the auth mechanism. HTTP currently specifies + ``Basic`` and ``Digest``.""", + ) + realm = auth_property( + "realm", + doc="""A string to be displayed to users so they know which + username and password to use. This string should contain at + least the name of the host performing the authentication and + might additionally indicate the collection of users who might + have access.""", + ) + domain = _set_property( + "domain", + doc="""A list of URIs that define the protection space. If a URI + is an absolute path, it is relative to the canonical root URL of + the server being accessed.""", + ) + nonce = auth_property( + "nonce", + doc=""" A server-specified data string which should be uniquely generated - each time a 401 response is made. It is recommended that this - string be base64 or hexadecimal data.''') - opaque = auth_property('opaque', doc=''' - A string of data, specified by the server, which should be returned - by the client unchanged in the Authorization header of subsequent - requests with URIs in the same protection space. It is recommended - that this string be base64 or hexadecimal data.''') - algorithm = auth_property('algorithm', doc=''' - A string indicating a pair of algorithms used to produce the digest - and a checksum. If this is not present it is assumed to be "MD5". - If the algorithm is not understood, the challenge should be ignored - (and a different one used, if there is more than one).''') - qop = _set_property('qop', doc=''' - A set of quality-of-privacy directives such as auth and auth-int.''') - - def _get_stale(self): - val = self.get('stale') + each time a 401 response is made. It is recommended that this + string be base64 or hexadecimal data.""", + ) + opaque = auth_property( + "opaque", + doc="""A string of data, specified by the server, which should + be returned by the client unchanged in the Authorization header + of subsequent requests with URIs in the same protection space. + It is recommended that this string be base64 or hexadecimal + data.""", + ) + algorithm = auth_property( + "algorithm", + doc="""A string indicating a pair of algorithms used to produce + the digest and a checksum. If this is not present it is assumed + to be "MD5". If the algorithm is not understood, the challenge + should be ignored (and a different one used, if there is more + than one).""", + ) + qop = _set_property( + "qop", + doc="""A set of quality-of-privacy directives such as auth and + auth-int.""", + ) + + @property + def stale(self): + """A flag, indicating that the previous request from the client + was rejected because the nonce value was stale. + """ + val = self.get("stale") if val is not None: - return val.lower() == 'true' + return val.lower() == "true" - def _set_stale(self, value): + @stale.setter + def stale(self, value): if value is None: - self.pop('stale', None) + self.pop("stale", None) else: - self['stale'] = 'TRUE' if value else 'FALSE' - stale = property(_get_stale, _set_stale, doc=''' - A flag, indicating that the previous request from the client was - rejected because the nonce value was stale.''') - del _get_stale, _set_stale - - # make auth_property a staticmethod so that subclasses of - # `WWWAuthenticate` can use it for new properties. + self["stale"] = "TRUE" if value else "FALSE" + auth_property = staticmethod(auth_property) del _set_property class FileStorage(object): - """The :class:`FileStorage` class is a thin wrapper over incoming files. It is used by the request object to represent uploaded files. All the attributes of the wrapper stream are proxied by the file storage so @@ -2663,9 +2704,15 @@ class FileStorage(object): ``storage.stream.read()``. """ - def __init__(self, stream=None, filename=None, name=None, - content_type=None, content_length=None, - headers=None): + def __init__( + self, + stream=None, + filename=None, + name=None, + content_type=None, + content_length=None, + headers=None, + ): self.name = name self.stream = stream or BytesIO() @@ -2674,41 +2721,39 @@ class FileStorage(object): # skip things like <fdopen>, <stderr> etc. Python marks these # special filenames with angular brackets. if filename is None: - filename = getattr(stream, 'name', None) + filename = getattr(stream, "name", None) s = make_literal_wrapper(filename) - if filename and filename[0] == s('<') and filename[-1] == s('>'): + if filename and filename[0] == s("<") and filename[-1] == s(">"): filename = None # On Python 3 we want to make sure the filename is always unicode. # This might not be if the name attribute is bytes due to the # file being opened from the bytes API. if not PY2 and isinstance(filename, bytes): - filename = filename.decode(get_filesystem_encoding(), - 'replace') + filename = filename.decode(get_filesystem_encoding(), "replace") self.filename = filename if headers is None: headers = Headers() self.headers = headers if content_type is not None: - headers['Content-Type'] = content_type + headers["Content-Type"] = content_type if content_length is not None: - headers['Content-Length'] = str(content_length) + headers["Content-Length"] = str(content_length) def _parse_content_type(self): - if not hasattr(self, '_parsed_content_type'): - self._parsed_content_type = \ - parse_options_header(self.content_type) + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = parse_options_header(self.content_type) @property def content_type(self): """The content-type sent in the header. Usually not available""" - return self.headers.get('content-type') + return self.headers.get("content-type") @property def content_length(self): """The content-length sent in the header. Usually not available""" - return int(self.headers.get('content-length') or 0) + return int(self.headers.get("content-length") or 0) @property def mimetype(self): @@ -2748,9 +2793,10 @@ class FileStorage(object): :func:`shutil.copyfileobj`. """ from shutil import copyfileobj + close_dst = False if isinstance(dst, string_types): - dst = open(dst, 'wb') + dst = open(dst, "wb") close_dst = True try: copyfileobj(self.stream, dst, buffer_size) @@ -2767,6 +2813,7 @@ class FileStorage(object): def __nonzero__(self): return bool(self.filename) + __bool__ = __nonzero__ def __getattr__(self, name): @@ -2784,15 +2831,22 @@ class FileStorage(object): return iter(self.stream) def __repr__(self): - return '<%s: %r (%r)>' % ( + return "<%s: %r (%r)>" % ( self.__class__.__name__, self.filename, - self.content_type + self.content_type, ) # circular dependencies -from werkzeug.http import dump_options_header, dump_header, generate_etag, \ - quote_header_value, parse_set_header, unquote_etag, quote_etag, \ - parse_options_header, http_date, is_byte_range_valid -from werkzeug import exceptions +from . import exceptions +from .http import dump_header +from .http import dump_options_header +from .http import generate_etag +from .http import http_date +from .http import is_byte_range_valid +from .http import parse_options_header +from .http import parse_set_header +from .http import quote_etag +from .http import quote_header_value +from .http import unquote_etag diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index cb984b03..beb7729e 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -8,33 +8,35 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import getpass +import hashlib +import json +import mimetypes import os import pkgutil import re import sys -import uuid -import json import time -import getpass -import hashlib -import mimetypes +import uuid from itertools import chain -from os.path import join, basename -from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response -from werkzeug.http import parse_cookie -from werkzeug.debug.tbtools import get_current_traceback, render_console_html -from werkzeug.debug.console import Console -from werkzeug.security import gen_salt -from werkzeug._internal import _log -from werkzeug._compat import text_type +from os.path import basename +from os.path import join - -# DEPRECATED -from werkzeug.debug.repr import debug_repr as _debug_repr +from .._compat import text_type +from .._internal import _log +from ..http import parse_cookie +from ..security import gen_salt +from ..wrappers import BaseRequest as Request +from ..wrappers import BaseResponse as Response +from .console import Console +from .repr import debug_repr as _debug_repr +from .tbtools import get_current_traceback +from .tbtools import render_console_html def debug_repr(*args, **kwargs): import warnings + warnings.warn( "'debug_repr' has moved to 'werkzeug.debug.repr.debug_repr'" " as of version 0.7. This old import will be removed in version" @@ -51,8 +53,8 @@ PIN_TIME = 60 * 60 * 24 * 7 def hash_pin(pin): if isinstance(pin, text_type): - pin = pin.encode('utf-8', 'replace') - return hashlib.md5(pin + b'shittysalt').hexdigest()[:12] + pin = pin.encode("utf-8", "replace") + return hashlib.md5(pin + b"shittysalt").hexdigest()[:12] _machine_id = None @@ -67,9 +69,9 @@ def get_machine_id(): def _generate(): # Potential sources of secret information on linux. The machine-id # is stable across boots, the boot id is not - for filename in '/etc/machine-id', '/proc/sys/kernel/random/boot_id': + for filename in "/etc/machine-id", "/proc/sys/kernel/random/boot_id": try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return f.readline().strip() except IOError: continue @@ -81,8 +83,10 @@ def get_machine_id(): # Google App Engine # See https://github.com/pallets/werkzeug/issues/925 from subprocess import Popen, PIPE - dump = Popen(['ioreg', '-c', 'IOPlatformExpertDevice', '-d', '2'], - stdout=PIPE).communicate()[0] + + dump = Popen( + ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE + ).communicate()[0] match = re.search(b'"serial-number" = <([^>]+)', dump) if match is not None: return match.group(1) @@ -100,12 +104,15 @@ def get_machine_id(): pass if wr is not None: try: - with wr.OpenKey(wr.HKEY_LOCAL_MACHINE, - 'SOFTWARE\\Microsoft\\Cryptography', 0, - wr.KEY_READ | wr.KEY_WOW64_64KEY) as rk: - machineGuid, wrType = wr.QueryValueEx(rk, 'MachineGuid') + with wr.OpenKey( + wr.HKEY_LOCAL_MACHINE, + "SOFTWARE\\Microsoft\\Cryptography", + 0, + wr.KEY_READ | wr.KEY_WOW64_64KEY, + ) as rk: + machineGuid, wrType = wr.QueryValueEx(rk, "MachineGuid") if wrType == wr.REG_SZ: - return machineGuid.encode('utf-8') + return machineGuid.encode("utf-8") else: return machineGuid except WindowsError: @@ -116,7 +123,6 @@ def get_machine_id(): class _ConsoleFrame(object): - """Helper class so that we can reuse the frame console code for the standalone console. """ @@ -134,24 +140,23 @@ def get_pin_and_cookie_name(app): Second item in the resulting tuple is the cookie name for remembering. """ - pin = os.environ.get('WERKZEUG_DEBUG_PIN') + pin = os.environ.get("WERKZEUG_DEBUG_PIN") rv = None num = None # Pin was explicitly disabled - if pin == 'off': + if pin == "off": return None, None # Pin was provided explicitly - if pin is not None and pin.replace('-', '').isdigit(): + if pin is not None and pin.replace("-", "").isdigit(): # If there are separators in the pin, return it directly - if '-' in pin: + if "-" in pin: rv = pin else: num = pin - modname = getattr(app, '__module__', - getattr(app.__class__, '__module__')) + modname = getattr(app, "__module__", getattr(app.__class__, "__module__")) try: # `getpass.getuser()` imports the `pwd` module, @@ -167,42 +172,41 @@ def get_pin_and_cookie_name(app): probably_public_bits = [ username, modname, - getattr(app, '__name__', getattr(app.__class__, '__name__')), - getattr(mod, '__file__', None), + getattr(app, "__name__", getattr(app.__class__, "__name__")), + getattr(mod, "__file__", None), ] # This information is here to make it harder for an attacker to # guess the cookie name. They are unlikely to be contained anywhere # within the unauthenticated debug page. - private_bits = [ - str(uuid.getnode()), - get_machine_id(), - ] + private_bits = [str(uuid.getnode()), get_machine_id()] h = hashlib.md5() for bit in chain(probably_public_bits, private_bits): if not bit: continue if isinstance(bit, text_type): - bit = bit.encode('utf-8') + bit = bit.encode("utf-8") h.update(bit) - h.update(b'cookiesalt') + h.update(b"cookiesalt") - cookie_name = '__wzd' + h.hexdigest()[:20] + cookie_name = "__wzd" + h.hexdigest()[:20] # If we need to generate a pin we salt it a bit more so that we don't # end up with the same value and generate out 9 digits if num is None: - h.update(b'pinsalt') - num = ('%09d' % int(h.hexdigest(), 16))[:9] + h.update(b"pinsalt") + num = ("%09d" % int(h.hexdigest(), 16))[:9] # Format the pincode in groups of digits for easier remembering if # we don't have a result yet. if rv is None: for group_size in 5, 4, 3: if len(num) % group_size == 0: - rv = '-'.join(num[x:x + group_size].rjust(group_size, '0') - for x in range(0, len(num), group_size)) + rv = "-".join( + num[x : x + group_size].rjust(group_size, "0") + for x in range(0, len(num), group_size) + ) break else: rv = num @@ -240,12 +244,21 @@ class DebuggedApplication(object): :param pin_logging: enables the logging of the pin system. """ - def __init__(self, app, evalex=False, request_key='werkzeug.request', - console_path='/console', console_init_func=None, - show_hidden_frames=False, lodgeit_url=None, - pin_security=True, pin_logging=True): + def __init__( + self, + app, + evalex=False, + request_key="werkzeug.request", + console_path="/console", + console_init_func=None, + show_hidden_frames=False, + lodgeit_url=None, + pin_security=True, + pin_logging=True, + ): if lodgeit_url is not None: from warnings import warn + warn( "'lodgeit_url' is no longer used as of version 0.9 and" " will be removed in version 1.0. Werkzeug uses" @@ -269,19 +282,17 @@ class DebuggedApplication(object): self.pin_logging = pin_logging if pin_security: # Print out the pin for the debugger on standard out. - if os.environ.get('WERKZEUG_RUN_MAIN') == 'true' and \ - pin_logging: - _log('warning', ' * Debugger is active!') + if os.environ.get("WERKZEUG_RUN_MAIN") == "true" and pin_logging: + _log("warning", " * Debugger is active!") if self.pin is None: - _log('warning', ' * Debugger PIN disabled. ' - 'DEBUGGER UNSECURED!') + _log("warning", " * Debugger PIN disabled. DEBUGGER UNSECURED!") else: - _log('info', ' * Debugger PIN: %s' % self.pin) + _log("info", " * Debugger PIN: %s" % self.pin) else: self.pin = None def _get_pin(self): - if not hasattr(self, '_pin'): + if not hasattr(self, "_pin"): self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) return self._pin @@ -294,7 +305,7 @@ class DebuggedApplication(object): @property def pin_cookie_name(self): """The name of the pin cookie.""" - if not hasattr(self, '_pin_cookie'): + if not hasattr(self, "_pin_cookie"): self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) return self._pin_cookie @@ -305,46 +316,51 @@ class DebuggedApplication(object): app_iter = self.app(environ, start_response) for item in app_iter: yield item - if hasattr(app_iter, 'close'): + if hasattr(app_iter, "close"): app_iter.close() except Exception: - if hasattr(app_iter, 'close'): + if hasattr(app_iter, "close"): app_iter.close() traceback = get_current_traceback( - skip=1, show_hidden_frames=self.show_hidden_frames, - ignore_system_exceptions=True) + skip=1, + show_hidden_frames=self.show_hidden_frames, + ignore_system_exceptions=True, + ) for frame in traceback.frames: self.frames[frame.id] = frame self.tracebacks[traceback.id] = traceback try: - start_response('500 INTERNAL SERVER ERROR', [ - ('Content-Type', 'text/html; charset=utf-8'), - # Disable Chrome's XSS protection, the debug - # output can cause false-positives. - ('X-XSS-Protection', '0'), - ]) + start_response( + "500 INTERNAL SERVER ERROR", + [ + ("Content-Type", "text/html; charset=utf-8"), + # Disable Chrome's XSS protection, the debug + # output can cause false-positives. + ("X-XSS-Protection", "0"), + ], + ) except Exception: # if we end up here there has been output but an error # occurred. in that situation we can do nothing fancy any # more, better log something into the error log and fall # back gracefully. - environ['wsgi.errors'].write( - 'Debugging middleware caught exception in streamed ' - 'response at a point where response headers were already ' - 'sent.\n') + environ["wsgi.errors"].write( + "Debugging middleware caught exception in streamed " + "response at a point where response headers were already " + "sent.\n" + ) else: is_trusted = bool(self.check_pin_trust(environ)) - yield traceback.render_full(evalex=self.evalex, - evalex_trusted=is_trusted, - secret=self.secret) \ - .encode('utf-8', 'replace') + yield traceback.render_full( + evalex=self.evalex, evalex_trusted=is_trusted, secret=self.secret + ).encode("utf-8", "replace") - traceback.log(environ['wsgi.errors']) + traceback.log(environ["wsgi.errors"]) def execute_command(self, request, command, frame): """Execute a command in a console.""" - return Response(frame.console.eval(command), mimetype='text/html') + return Response(frame.console.eval(command), mimetype="text/html") def display_console(self, request): """Display a standalone shell.""" @@ -353,30 +369,30 @@ class DebuggedApplication(object): ns = {} else: ns = dict(self.console_init_func()) - ns.setdefault('app', self.app) + ns.setdefault("app", self.app) self.frames[0] = _ConsoleFrame(ns) is_trusted = bool(self.check_pin_trust(request.environ)) - return Response(render_console_html(secret=self.secret, - evalex_trusted=is_trusted), - mimetype='text/html') + return Response( + render_console_html(secret=self.secret, evalex_trusted=is_trusted), + mimetype="text/html", + ) def paste_traceback(self, request, traceback): """Paste the traceback and return a JSON response.""" rv = traceback.paste() - return Response(json.dumps(rv), mimetype='application/json') + return Response(json.dumps(rv), mimetype="application/json") def get_resource(self, request, filename): """Return a static resource from the shared folder.""" - filename = join('shared', basename(filename)) + filename = join("shared", basename(filename)) try: data = pkgutil.get_data(__package__, filename) except OSError: data = None if data is not None: - mimetype = mimetypes.guess_type(filename)[0] \ - or 'application/octet-stream' + mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream" return Response(data, mimetype=mimetype) - return Response('Not Found', status=404) + return Response("Not Found", status=404) def check_pin_trust(self, environ): """Checks if the request passed the pin test. This returns `True` if the @@ -387,9 +403,9 @@ class DebuggedApplication(object): if self.pin is None: return True val = parse_cookie(environ).get(self.pin_cookie_name) - if not val or '|' not in val: + if not val or "|" not in val: return False - ts, pin_hash = val.split('|', 1) + ts, pin_hash = val.split("|", 1) if not ts.isdigit(): return False if pin_hash != hash_pin(self.pin): @@ -426,23 +442,23 @@ class DebuggedApplication(object): # Otherwise go through pin based authentication else: - entered_pin = request.args.get('pin') - if entered_pin.strip().replace('-', '') == \ - self.pin.replace('-', ''): + entered_pin = request.args.get("pin") + if entered_pin.strip().replace("-", "") == self.pin.replace("-", ""): self._failed_pin_auth = 0 auth = True else: self._fail_pin_auth() - rv = Response(json.dumps({ - 'auth': auth, - 'exhausted': exhausted, - }), mimetype='application/json') + rv = Response( + json.dumps({"auth": auth, "exhausted": exhausted}), + mimetype="application/json", + ) if auth: - rv.set_cookie(self.pin_cookie_name, '%s|%s' % ( - int(time.time()), - hash_pin(self.pin) - ), httponly=True) + rv.set_cookie( + self.pin_cookie_name, + "%s|%s" % (int(time.time()), hash_pin(self.pin)), + httponly=True, + ) elif bad_cookie: rv.delete_cookie(self.pin_cookie_name) return rv @@ -450,10 +466,11 @@ class DebuggedApplication(object): def log_pin_request(self): """Log the pin if needed.""" if self.pin_logging and self.pin is not None: - _log('info', ' * To enable the debugger you need to ' - 'enter the security pin:') - _log('info', ' * Debugger pin code: %s' % self.pin) - return Response('') + _log( + "info", " * To enable the debugger you need to enter the security pin:" + ) + _log("info", " * Debugger pin code: %s" % self.pin) + return Response("") def __call__(self, environ, start_response): """Dispatch the requests.""" @@ -462,26 +479,32 @@ class DebuggedApplication(object): # any more! request = Request(environ) response = self.debug_application - if request.args.get('__debugger__') == 'yes': - cmd = request.args.get('cmd') - arg = request.args.get('f') - secret = request.args.get('s') - traceback = self.tracebacks.get(request.args.get('tb', type=int)) - frame = self.frames.get(request.args.get('frm', type=int)) - if cmd == 'resource' and arg: + if request.args.get("__debugger__") == "yes": + cmd = request.args.get("cmd") + arg = request.args.get("f") + secret = request.args.get("s") + traceback = self.tracebacks.get(request.args.get("tb", type=int)) + frame = self.frames.get(request.args.get("frm", type=int)) + if cmd == "resource" and arg: response = self.get_resource(request, arg) - elif cmd == 'paste' and traceback is not None and \ - secret == self.secret: + elif cmd == "paste" and traceback is not None and secret == self.secret: response = self.paste_traceback(request, traceback) - elif cmd == 'pinauth' and secret == self.secret: + elif cmd == "pinauth" and secret == self.secret: response = self.pin_auth(request) - elif cmd == 'printpin' and secret == self.secret: + elif cmd == "printpin" and secret == self.secret: response = self.log_pin_request() - elif self.evalex and cmd is not None and frame is not None \ - and self.secret == secret and \ - self.check_pin_trust(environ): + elif ( + self.evalex + and cmd is not None + and frame is not None + and self.secret == secret + and self.check_pin_trust(environ) + ): response = self.execute_command(request, cmd, frame) - elif self.evalex and self.console_path is not None and \ - request.path == self.console_path: + elif ( + self.evalex + and self.console_path is not None + and request.path == self.console_path + ): response = self.display_console(request) return response(environ, start_response) diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 2915aea1..adbd170b 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -8,20 +8,21 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys import code +import sys from types import CodeType -from werkzeug.utils import escape -from werkzeug.local import Local -from werkzeug.debug.repr import debug_repr, dump, helper +from ..local import Local +from ..utils import escape +from .repr import debug_repr +from .repr import dump +from .repr import helper _local = Local() class HTMLStringO(object): - """A StringO version that HTML escapes on write.""" def __init__(self): @@ -41,46 +42,46 @@ class HTMLStringO(object): def readline(self): if len(self._buffer) == 0: - return '' + return "" ret = self._buffer[0] del self._buffer[0] return ret def reset(self): - val = ''.join(self._buffer) + val = "".join(self._buffer) del self._buffer[:] return val def _write(self, x): if isinstance(x, bytes): - x = x.decode('utf-8', 'replace') + x = x.decode("utf-8", "replace") self._buffer.append(x) def write(self, x): self._write(escape(x)) def writelines(self, x): - self._write(escape(''.join(x))) + self._write(escape("".join(x))) class ThreadedStream(object): - """Thread-local wrapper for sys.stdout for the interactive console.""" + @staticmethod def push(): if not isinstance(sys.stdout, ThreadedStream): sys.stdout = ThreadedStream() _local.stream = HTMLStringO() - push = staticmethod(push) + @staticmethod def fetch(): try: stream = _local.stream except AttributeError: - return '' + return "" return stream.reset() - fetch = staticmethod(fetch) + @staticmethod def displayhook(obj): try: stream = _local.stream @@ -89,18 +90,17 @@ class ThreadedStream(object): # stream._write bypasses escaping as debug_repr is # already generating HTML for us. if obj is not None: - _local._current_ipy.locals['_'] = obj + _local._current_ipy.locals["_"] = obj stream._write(debug_repr(obj)) - displayhook = staticmethod(displayhook) def __setattr__(self, name, value): - raise AttributeError('read only attribute %s' % name) + raise AttributeError("read only attribute %s" % name) def __dir__(self): return dir(sys.__stdout__) def __getattribute__(self, name): - if name == '__members__': + if name == "__members__": return dir(sys.__stdout__) try: stream = _local.stream @@ -118,7 +118,6 @@ sys.displayhook = ThreadedStream.displayhook class _ConsoleLoader(object): - def __init__(self): self._storage = {} @@ -143,29 +142,30 @@ def _wrap_compiler(console): code = compile(source, filename, symbol) console.loader.register(code, source) return code + console.compile = func class _InteractiveConsole(code.InteractiveInterpreter): - def __init__(self, globals, locals): code.InteractiveInterpreter.__init__(self, locals) self.globals = dict(globals) - self.globals['dump'] = dump - self.globals['help'] = helper - self.globals['__loader__'] = self.loader = _ConsoleLoader() + self.globals["dump"] = dump + self.globals["help"] = helper + self.globals["__loader__"] = self.loader = _ConsoleLoader() self.more = False self.buffer = [] _wrap_compiler(self) def runsource(self, source): - source = source.rstrip() + '\n' + source = source.rstrip() + "\n" ThreadedStream.push() - prompt = '... ' if self.more else '>>> ' + prompt = "... " if self.more else ">>> " try: - source_to_eval = ''.join(self.buffer + [source]) - if code.InteractiveInterpreter.runsource(self, - source_to_eval, '<debugger>', 'single'): + source_to_eval = "".join(self.buffer + [source]) + if code.InteractiveInterpreter.runsource( + self, source_to_eval, "<debugger>", "single" + ): self.more = True self.buffer.append(source) else: @@ -182,12 +182,14 @@ class _InteractiveConsole(code.InteractiveInterpreter): self.showtraceback() def showtraceback(self): - from werkzeug.debug.tbtools import get_current_traceback + from .tbtools import get_current_traceback + tb = get_current_traceback(skip=1) sys.stdout._write(tb.render_summary()) def showsyntaxerror(self, filename=None): - from werkzeug.debug.tbtools import get_current_traceback + from .tbtools import get_current_traceback + tb = get_current_traceback(skip=4) sys.stdout._write(tb.render_summary()) @@ -196,7 +198,6 @@ class _InteractiveConsole(code.InteractiveInterpreter): class Console(object): - """An interactive console.""" def __init__(self, globals=None, locals=None): diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index e82d87c4..d7a7285c 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -13,37 +13,38 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys -import re import codecs +import re +import sys +from collections import deque from traceback import format_exception_only -try: - from collections import deque -except ImportError: # pragma: no cover - deque = None -from werkzeug.utils import escape -from werkzeug._compat import iteritems, PY2, text_type, integer_types, \ - string_types + +from .._compat import integer_types +from .._compat import iteritems +from .._compat import PY2 +from .._compat import string_types +from .._compat import text_type +from ..utils import escape missing = object() -_paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}') +_paragraph_re = re.compile(r"(?:\r\n|\r|\n){2,}") RegexType = type(_paragraph_re) -HELP_HTML = '''\ +HELP_HTML = """\ <div class=box> <h3>%(title)s</h3> <pre class=help>%(text)s</pre> </div>\ -''' -OBJECT_DUMP_HTML = '''\ +""" +OBJECT_DUMP_HTML = """\ <div class=box> <h3>%(title)s</h3> %(repr)s <table>%(items)s</table> </div>\ -''' +""" def debug_repr(obj): @@ -64,31 +65,31 @@ def dump(obj=missing): class _Helper(object): - """Displays an HTML version of the normal help, for the interactive debugger only because it requires a patched sys.stdout. """ def __repr__(self): - return 'Type help(object) for help about object.' + return "Type help(object) for help about object." def __call__(self, topic=None): if topic is None: - sys.stdout._write('<span class=help>%s</span>' % repr(self)) + sys.stdout._write("<span class=help>%s</span>" % repr(self)) return import pydoc + pydoc.help(topic) rv = sys.stdout.reset() if isinstance(rv, bytes): - rv = rv.decode('utf-8', 'ignore') + rv = rv.decode("utf-8", "ignore") paragraphs = _paragraph_re.split(rv) if len(paragraphs) > 1: title = paragraphs[0] - text = '\n\n'.join(paragraphs[1:]) + text = "\n\n".join(paragraphs[1:]) else: # pragma: no cover - title = 'Help' + title = "Help" text = paragraphs[0] - sys.stdout._write(HELP_HTML % {'title': title, 'text': text}) + sys.stdout._write(HELP_HTML % {"title": title, "text": text}) helper = _Helper() @@ -101,55 +102,55 @@ def _add_subclass_info(inner, obj, base): return inner elif type(obj) is base: return inner - module = '' - if obj.__class__.__module__ not in ('__builtin__', 'exceptions'): + module = "" + if obj.__class__.__module__ not in ("__builtin__", "exceptions"): module = '<span class="module">%s.</span>' % obj.__class__.__module__ - return '%s%s(%s)' % (module, obj.__class__.__name__, inner) + return "%s%s(%s)" % (module, obj.__class__.__name__, inner) class DebugReprGenerator(object): - def __init__(self): self._stack = [] - def _sequence_repr_maker(left, right, base=object(), limit=8): + def _sequence_repr_maker(left, right, base=object(), limit=8): # noqa: B008, B902 def proxy(self, obj, recursive): if recursive: - return _add_subclass_info(left + '...' + right, obj, base) + return _add_subclass_info(left + "..." + right, obj, base) buf = [left] have_extended_section = False for idx, item in enumerate(obj): if idx: - buf.append(', ') + buf.append(", ") if idx == limit: buf.append('<span class="extended">') have_extended_section = True buf.append(self.repr(item)) if have_extended_section: - buf.append('</span>') + buf.append("</span>") buf.append(right) - return _add_subclass_info(u''.join(buf), obj, base) + return _add_subclass_info(u"".join(buf), obj, base) + return proxy - list_repr = _sequence_repr_maker('[', ']', list) - tuple_repr = _sequence_repr_maker('(', ')', tuple) - set_repr = _sequence_repr_maker('set([', '])', set) - frozenset_repr = _sequence_repr_maker('frozenset([', '])', frozenset) - if deque is not None: - deque_repr = _sequence_repr_maker('<span class="module">collections.' - '</span>deque([', '])', deque) + list_repr = _sequence_repr_maker("[", "]", list) + tuple_repr = _sequence_repr_maker("(", ")", tuple) + set_repr = _sequence_repr_maker("set([", "])", set) + frozenset_repr = _sequence_repr_maker("frozenset([", "])", frozenset) + deque_repr = _sequence_repr_maker( + '<span class="module">collections.' "</span>deque([", "])", deque + ) del _sequence_repr_maker def regex_repr(self, obj): pattern = repr(obj.pattern) if PY2: - pattern = pattern.decode('string-escape', 'ignore') + pattern = pattern.decode("string-escape", "ignore") else: - pattern = codecs.decode(pattern, 'unicode-escape', 'ignore') - if pattern[:1] == 'u': - pattern = 'ur' + pattern[1:] + pattern = codecs.decode(pattern, "unicode-escape", "ignore") + if pattern[:1] == "u": + pattern = "ur" + pattern[1:] else: - pattern = 'r' + pattern + pattern = "r" + pattern return u're.compile(<span class="string regex">%s</span>)' % pattern def string_repr(self, obj, limit=70): @@ -158,14 +159,18 @@ class DebugReprGenerator(object): # shorten the repr when the hidden part would be at least 3 chars if len(r) - limit > 2: - buf.extend(( - escape(r[:limit]), - '<span class="extended">', escape(r[limit:]), '</span>', - )) + buf.extend( + ( + escape(r[:limit]), + '<span class="extended">', + escape(r[limit:]), + "</span>", + ) + ) else: buf.append(escape(r)) - buf.append('</span>') + buf.append("</span>") out = u"".join(buf) # if the repr looks like a standard string, add subclass info if needed @@ -177,27 +182,29 @@ class DebugReprGenerator(object): def dict_repr(self, d, recursive, limit=5): if recursive: - return _add_subclass_info(u'{...}', d, dict) - buf = ['{'] + return _add_subclass_info(u"{...}", d, dict) + buf = ["{"] have_extended_section = False for idx, (key, value) in enumerate(iteritems(d)): if idx: - buf.append(', ') + buf.append(", ") if idx == limit - 1: buf.append('<span class="extended">') have_extended_section = True - buf.append('<span class="pair"><span class="key">%s</span>: ' - '<span class="value">%s</span></span>' % - (self.repr(key), self.repr(value))) + buf.append( + '<span class="pair"><span class="key">%s</span>: ' + '<span class="value">%s</span></span>' + % (self.repr(key), self.repr(value)) + ) if have_extended_section: - buf.append('</span>') - buf.append('}') - return _add_subclass_info(u''.join(buf), d, dict) + buf.append("</span>") + buf.append("}") + return _add_subclass_info(u"".join(buf), d, dict) def object_repr(self, obj): r = repr(obj) if PY2: - r = r.decode('utf-8', 'replace') + r = r.decode("utf-8", "replace") return u'<span class="object">%s</span>' % escape(r) def dispatch_repr(self, obj, recursive): @@ -225,13 +232,14 @@ class DebugReprGenerator(object): def fallback_repr(self): try: - info = ''.join(format_exception_only(*sys.exc_info()[:2])) + info = "".join(format_exception_only(*sys.exc_info()[:2])) except Exception: # pragma: no cover - info = '?' + info = "?" if PY2: - info = info.decode('utf-8', 'ignore') - return u'<span class="brokenrepr"><broken repr (%s)>' \ - u'</span>' % escape(info.strip()) + info = info.decode("utf-8", "ignore") + return u'<span class="brokenrepr"><broken repr (%s)>' u"</span>" % escape( + info.strip() + ) def repr(self, obj): recursive = False @@ -251,7 +259,7 @@ class DebugReprGenerator(object): def dump_object(self, obj): repr = items = None if isinstance(obj, dict): - title = 'Contents of' + title = "Contents of" items = [] for key, value in iteritems(obj): if not isinstance(key, string_types): @@ -266,23 +274,24 @@ class DebugReprGenerator(object): items.append((key, self.repr(getattr(obj, key)))) except Exception: pass - title = 'Details for' - title += ' ' + object.__repr__(obj)[1:-1] + title = "Details for" + title += " " + object.__repr__(obj)[1:-1] return self.render_object_dump(items, title, repr) def dump_locals(self, d): items = [(key, self.repr(value)) for key, value in d.items()] - return self.render_object_dump(items, 'Local variables in frame') + return self.render_object_dump(items, "Local variables in frame") def render_object_dump(self, items, title, repr=None): html_items = [] for key, value in items: - html_items.append('<tr><th>%s<td><pre class=repr>%s</pre>' % - (escape(key), value)) + html_items.append( + "<tr><th>%s<td><pre class=repr>%s</pre>" % (escape(key), value) + ) if not html_items: - html_items.append('<tr><td><em>Nothing</em>') + html_items.append("<tr><td><em>Nothing</em>") return OBJECT_DUMP_HTML % { - 'title': escape(title), - 'repr': '<pre class=repr>%s</pre>' % repr if repr else '', - 'items': '\n'.join(html_items) + "title": escape(title), + "repr": "<pre class=repr>%s</pre>" % repr if repr else "", + "items": "\n".join(html_items), } diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index 9422052f..f6af4e30 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -8,30 +8,33 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re - +import codecs +import inspect +import json import os +import re import sys -import json -import inspect import sysconfig import traceback -import codecs from tokenize import TokenError -from werkzeug.utils import cached_property, escape -from werkzeug.debug.console import Console -from werkzeug._compat import ( - range_type, PY2, text_type, string_types, - to_native, to_unicode, reraise, -) -from werkzeug.filesystem import get_filesystem_encoding +from .._compat import PY2 +from .._compat import range_type +from .._compat import reraise +from .._compat import string_types +from .._compat import text_type +from .._compat import to_native +from .._compat import to_unicode +from ..filesystem import get_filesystem_encoding +from ..utils import cached_property +from ..utils import escape +from .console import Console -_coding_re = re.compile(br'coding[:=]\s*([-\w.]+)') -_line_re = re.compile(br'^(.*?)$', re.MULTILINE) -_funcdef_re = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)') -UTF8_COOKIE = b'\xef\xbb\xbf' +_coding_re = re.compile(br"coding[:=]\s*([-\w.]+)") +_line_re = re.compile(br"^(.*?)$", re.MULTILINE) +_funcdef_re = re.compile(r"^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)") +UTF8_COOKIE = b"\xef\xbb\xbf" system_exceptions = (SystemExit, KeyboardInterrupt) try: @@ -40,7 +43,7 @@ except NameError: pass -HEADER = u'''\ +HEADER = u"""\ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/TR/html4/loose.dtd"> <html> @@ -65,8 +68,8 @@ HEADER = u'''\ </head> <body style="background-color: #fff"> <div class="debugger"> -''' -FOOTER = u'''\ +""" +FOOTER = u"""\ <div class="footer"> Brought to you by <strong class="arthur">DON'T PANIC</strong>, your friendly Werkzeug powered traceback interpreter. @@ -89,9 +92,11 @@ FOOTER = u'''\ </div> </body> </html> -''' +""" -PAGE_HTML = HEADER + u'''\ +PAGE_HTML = ( + HEADER + + u"""\ <h1>%(exception_type)s</h1> <div class="detail"> <p class="errormsg">%(exception)s</p> @@ -117,61 +122,69 @@ PAGE_HTML = HEADER + u'''\ execution (if the evalex feature is enabled), automatic pasting of the exceptions and much more.</span> </div> -''' + FOOTER + ''' +""" + + FOOTER + + """ <!-- %(plaintext_cs)s --> -''' +""" +) -CONSOLE_HTML = HEADER + u'''\ +CONSOLE_HTML = ( + HEADER + + u"""\ <h1>Interactive Console</h1> <div class="explanation"> In this console you can execute Python expressions in the context of the application. The initial namespace was created by the debugger automatically. </div> <div class="console"><div class="inner">The Console requires JavaScript.</div></div> -''' + FOOTER +""" + + FOOTER +) -SUMMARY_HTML = u'''\ +SUMMARY_HTML = u"""\ <div class="%(classes)s"> %(title)s <ul>%(frames)s</ul> %(description)s </div> -''' +""" -FRAME_HTML = u'''\ +FRAME_HTML = u"""\ <div class="frame" id="frame-%(id)d"> <h4>File <cite class="filename">"%(filename)s"</cite>, line <em class="line">%(lineno)s</em>, in <code class="function">%(function_name)s</code></h4> <div class="source %(library)s">%(lines)s</div> </div> -''' +""" -SOURCE_LINE_HTML = u'''\ +SOURCE_LINE_HTML = u"""\ <tr class="%(classes)s"> <td class=lineno>%(lineno)s</td> <td>%(code)s</td> </tr> -''' +""" def render_console_html(secret, evalex_trusted=True): return CONSOLE_HTML % { - 'evalex': 'true', - 'evalex_trusted': 'true' if evalex_trusted else 'false', - 'console': 'true', - 'title': 'Console', - 'secret': secret, - 'traceback_id': -1 + "evalex": "true", + "evalex_trusted": "true" if evalex_trusted else "false", + "console": "true", + "title": "Console", + "secret": secret, + "traceback_id": -1, } -def get_current_traceback(ignore_system_exceptions=False, - show_hidden_frames=False, skip=0): +def get_current_traceback( + ignore_system_exceptions=False, show_hidden_frames=False, skip=0 +): """Get the current exception info as `Traceback` object. Per default calling this method will reraise system exceptions such as generator exit, system exit or others. This behavior can be disabled by passing `False` @@ -192,7 +205,8 @@ def get_current_traceback(ignore_system_exceptions=False, class Line(object): """Helper for the source renderer.""" - __slots__ = ('lineno', 'code', 'in_frame', 'current') + + __slots__ = ("lineno", "code", "in_frame", "current") def __init__(self, lineno, code): self.lineno = lineno @@ -202,18 +216,18 @@ class Line(object): @property def classes(self): - rv = ['line'] + rv = ["line"] if self.in_frame: - rv.append('in-frame') + rv.append("in-frame") if self.current: - rv.append('current') + rv.append("current") return rv def render(self): return SOURCE_LINE_HTML % { - 'classes': u' '.join(self.classes), - 'lineno': self.lineno, - 'code': escape(self.code) + "classes": u" ".join(self.classes), + "lineno": self.lineno, + "code": escape(self.code), } @@ -227,7 +241,7 @@ class Traceback(object): exception_type = exc_type.__name__ if exc_type.__module__ not in {"builtins", "__builtin__", "exceptions"}: - exception_type = exc_type.__module__ + '.' + exception_type + exception_type = exc_type.__module__ + "." + exception_type self.exception_type = exception_type self.groups = [] @@ -264,38 +278,33 @@ class Traceback(object): """Log the ASCII traceback into a file object.""" if logfile is None: logfile = sys.stderr - tb = self.plaintext.rstrip() + u'\n' + tb = self.plaintext.rstrip() + u"\n" logfile.write(to_native(tb, "utf-8", "replace")) def paste(self): """Create a paste and return the paste id.""" - data = json.dumps({ - 'description': 'Werkzeug Internal Server Error', - 'public': False, - 'files': { - 'traceback.txt': { - 'content': self.plaintext - } + data = json.dumps( + { + "description": "Werkzeug Internal Server Error", + "public": False, + "files": {"traceback.txt": {"content": self.plaintext}}, } - }).encode('utf-8') + ).encode("utf-8") try: from urllib2 import urlopen except ImportError: from urllib.request import urlopen - rv = urlopen('https://api.github.com/gists', data=data) - resp = json.loads(rv.read().decode('utf-8')) + rv = urlopen("https://api.github.com/gists", data=data) + resp = json.loads(rv.read().decode("utf-8")) rv.close() - return { - 'url': resp['html_url'], - 'id': resp['id'] - } + return {"url": resp["html_url"], "id": resp["id"]} def render_summary(self, include_title=True): """Render the traceback for the interactive console.""" - title = '' - classes = ['traceback'] + title = "" + classes = ["traceback"] if not self.frames: - classes.append('noframe-traceback') + classes.append("noframe-traceback") frames = [] else: library_frames = sum(frame.is_library for frame in self.frames) @@ -304,38 +313,37 @@ class Traceback(object): if include_title: if self.is_syntax_error: - title = u'Syntax Error' + title = u"Syntax Error" else: - title = u'Traceback <em>(most recent call last)</em>:' + title = u"Traceback <em>(most recent call last)</em>:" if self.is_syntax_error: - description_wrapper = u'<pre class=syntaxerror>%s</pre>' + description_wrapper = u"<pre class=syntaxerror>%s</pre>" else: - description_wrapper = u'<blockquote>%s</blockquote>' + description_wrapper = u"<blockquote>%s</blockquote>" return SUMMARY_HTML % { - 'classes': u' '.join(classes), - 'title': u'<h3>%s</h3>' % title if title else u'', - 'frames': u'\n'.join(frames), - 'description': description_wrapper % escape(self.exception) + "classes": u" ".join(classes), + "title": u"<h3>%s</h3>" % title if title else u"", + "frames": u"\n".join(frames), + "description": description_wrapper % escape(self.exception), } - def render_full(self, evalex=False, secret=None, - evalex_trusted=True): + def render_full(self, evalex=False, secret=None, evalex_trusted=True): """Render the Full HTML page with the traceback info.""" exc = escape(self.exception) return PAGE_HTML % { - 'evalex': 'true' if evalex else 'false', - 'evalex_trusted': 'true' if evalex_trusted else 'false', - 'console': 'false', - 'title': exc, - 'exception': exc, - 'exception_type': escape(self.exception_type), - 'summary': self.render_summary(include_title=False), - 'plaintext': escape(self.plaintext), - 'plaintext_cs': re.sub('-{2,}', '-', self.plaintext), - 'traceback_id': self.id, - 'secret': secret + "evalex": "true" if evalex else "false", + "evalex_trusted": "true" if evalex_trusted else "false", + "console": "false", + "title": exc, + "exception": exc, + "exception_type": escape(self.exception_type), + "summary": self.render_summary(include_title=False), + "plaintext": escape(self.plaintext), + "plaintext_cs": re.sub("-{2,}", "-", self.plaintext), + "traceback_id": self.id, + "secret": secret, } @cached_property @@ -418,10 +426,13 @@ class Group(object): if self.info is not None: out.append(u'<li><div class="exc-divider">%s:</div>' % self.info) for frame in self.frames: - out.append(u"<li%s>%s" % ( - u' title="%s"' % escape(frame.info) if frame.info else u"", - frame.render(mark_lib=mark_lib) - )) + out.append( + u"<li%s>%s" + % ( + u' title="%s"' % escape(frame.info) if frame.info else u"", + frame.render(mark_lib=mark_lib), + ) + ) return u"\n".join(out) def render_text(self): @@ -436,7 +447,6 @@ class Group(object): class Frame(object): - """A single frame in a traceback.""" def __init__(self, exc_type, exc_value, tb): @@ -446,19 +456,19 @@ class Frame(object): self.globals = tb.tb_frame.f_globals fn = inspect.getsourcefile(tb) or inspect.getfile(tb) - if fn[-4:] in ('.pyo', '.pyc'): + if fn[-4:] in (".pyo", ".pyc"): fn = fn[:-1] # if it's a file on the file system resolve the real filename. if os.path.isfile(fn): fn = os.path.realpath(fn) self.filename = to_unicode(fn, get_filesystem_encoding()) - self.module = self.globals.get('__name__') - self.loader = self.globals.get('__loader__') + self.module = self.globals.get("__name__") + self.loader = self.globals.get("__loader__") self.code = tb.tb_frame.f_code # support for paste's traceback extensions - self.hide = self.locals.get('__traceback_hide__', False) - info = self.locals.get('__traceback_info__') + self.hide = self.locals.get("__traceback_hide__", False) + info = self.locals.get("__traceback_info__") if info is not None: info = to_unicode(info, "utf-8", "replace") self.info = info @@ -466,11 +476,11 @@ class Frame(object): def render(self, mark_lib=True): """Render a single frame in a traceback.""" return FRAME_HTML % { - 'id': self.id, - 'filename': escape(self.filename), - 'lineno': self.lineno, - 'function_name': escape(self.function_name), - 'lines': self.render_line_context(), + "id": self.id, + "filename": escape(self.filename), + "lineno": self.lineno, + "function_name": escape(self.function_name), + "lines": self.render_line_context(), "library": "library" if mark_lib and self.is_library else "", } @@ -485,7 +495,7 @@ class Frame(object): self.filename, self.lineno, self.function_name, - self.current_line.strip() + self.current_line.strip(), ) def render_line_context(self): @@ -497,34 +507,34 @@ class Frame(object): stripped_line = line.strip() prefix = len(line) - len(stripped_line) rv.append( - '<pre class="line %s"><span class="ws">%s</span>%s</pre>' % ( - cls, ' ' * prefix, escape(stripped_line) or ' ')) + '<pre class="line %s"><span class="ws">%s</span>%s</pre>' + % (cls, " " * prefix, escape(stripped_line) or " ") + ) for line in before: - render_line(line, 'before') - render_line(current, 'current') + render_line(line, "before") + render_line(current, "current") for line in after: - render_line(line, 'after') + render_line(line, "after") - return '\n'.join(rv) + return "\n".join(rv) def get_annotated_lines(self): """Helper function that returns lines with extra information.""" lines = [Line(idx + 1, x) for idx, x in enumerate(self.sourcelines)] # find function definition and mark lines - if hasattr(self.code, 'co_firstlineno'): + if hasattr(self.code, "co_firstlineno"): lineno = self.code.co_firstlineno - 1 while lineno > 0: if _funcdef_re.match(lines[lineno].code): break lineno -= 1 try: - offset = len(inspect.getblock([x.code + '\n' for x - in lines[lineno:]])) + offset = len(inspect.getblock([x.code + "\n" for x in lines[lineno:]])) except TokenError: offset = 0 - for line in lines[lineno:lineno + offset]: + for line in lines[lineno : lineno + offset]: line.in_frame = True # mark current line @@ -535,12 +545,12 @@ class Frame(object): return lines - def eval(self, code, mode='single'): + def eval(self, code, mode="single"): """Evaluate code in the context of the frame.""" if isinstance(code, string_types): if PY2 and isinstance(code, text_type): # noqa - code = UTF8_COOKIE + code.encode('utf-8') - code = compile(code, '<interactive>', mode) + code = UTF8_COOKIE + code.encode("utf-8") + code = compile(code, "<interactive>", mode) return eval(code, self.globals, self.locals) @cached_property @@ -550,9 +560,9 @@ class Frame(object): source = None if self.loader is not None: try: - if hasattr(self.loader, 'get_source'): + if hasattr(self.loader, "get_source"): source = self.loader.get_source(self.module) - elif hasattr(self.loader, 'get_source_by_code'): + elif hasattr(self.loader, "get_source_by_code"): source = self.loader.get_source_by_code(self.code) except Exception: # we munch the exception so that we don't cause troubles @@ -561,8 +571,7 @@ class Frame(object): if source is None: try: - f = open(to_native(self.filename, get_filesystem_encoding()), - mode='rb') + f = open(to_native(self.filename, get_filesystem_encoding()), mode="rb") except IOError: return [] try: @@ -576,7 +585,7 @@ class Frame(object): # yes. it should be ascii, but we don't want to reject too many # characters in the debugger if something breaks - charset = 'utf-8' + charset = "utf-8" if source.startswith(UTF8_COOKIE): source = source[3:] else: @@ -593,25 +602,21 @@ class Frame(object): try: codecs.lookup(charset) except LookupError: - charset = 'utf-8' + charset = "utf-8" - return source.decode(charset, 'replace').splitlines() + return source.decode(charset, "replace").splitlines() def get_context_lines(self, context=5): - before = self.sourcelines[self.lineno - context - 1:self.lineno - 1] - past = self.sourcelines[self.lineno:self.lineno + context] - return ( - before, - self.current_line, - past, - ) + before = self.sourcelines[self.lineno - context - 1 : self.lineno - 1] + past = self.sourcelines[self.lineno : self.lineno + context] + return (before, self.current_line, past) @property def current_line(self): try: return self.sourcelines[self.lineno - 1] except IndexError: - return u'' + return u"" @cached_property def console(self): diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index e2c2cb3e..0e6d1ce3 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -59,23 +59,23 @@ """ import sys +import werkzeug + # Because of bootstrapping reasons we need to manually patch ourselves # onto our parent module. -import werkzeug werkzeug.exceptions = sys.modules[__name__] -from werkzeug._internal import _get_environ -from werkzeug._compat import iteritems, integer_types, text_type, \ - implements_to_string - -from werkzeug.wrappers import Response +from ._compat import implements_to_string +from ._compat import integer_types +from ._compat import iteritems +from ._compat import text_type +from ._internal import _get_environ +from .wrappers import Response @implements_to_string class HTTPException(Exception): - - """ - Baseclass for all HTTP exceptions. This exception can be called as WSGI + """Baseclass for all HTTP exceptions. This exception can be called as WSGI application to render a default error page or you can catch the subclasses of it independently and render nicer error messages. """ @@ -102,6 +102,7 @@ class HTTPException(Exception): .. versionchanged:: 0.15 The description includes the wrapped exception message. """ + class newcls(cls, exception): def __init__(self, arg=None, *args, **kwargs): super(cls, self).__init__(*args, **kwargs) @@ -116,41 +117,43 @@ class HTTPException(Exception): if self.args: out += "<p><pre><code>{}: {}</code></pre></p>".format( - exception.__name__, - escape(exception.__str__(self)) + exception.__name__, escape(exception.__str__(self)) ) return out - newcls.__module__ = sys._getframe(1).f_globals.get('__name__') + newcls.__module__ = sys._getframe(1).f_globals.get("__name__") newcls.__name__ = name or cls.__name__ + exception.__name__ return newcls @property def name(self): """The status name.""" - return HTTP_STATUS_CODES.get(self.code, 'Unknown Error') + return HTTP_STATUS_CODES.get(self.code, "Unknown Error") def get_description(self, environ=None): """Get the description.""" - return u'<p>%s</p>' % escape(self.description) + return u"<p>%s</p>" % escape(self.description) def get_body(self, environ=None): """Get the HTML body.""" - return text_type(( - u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' - u'<title>%(code)s %(name)s</title>\n' - u'<h1>%(name)s</h1>\n' - u'%(description)s\n' - ) % { - 'code': self.code, - 'name': escape(self.name), - 'description': self.get_description(environ) - }) + return text_type( + ( + u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' + u"<title>%(code)s %(name)s</title>\n" + u"<h1>%(name)s</h1>\n" + u"%(description)s\n" + ) + % { + "code": self.code, + "name": escape(self.name), + "description": self.get_description(environ), + } + ) def get_headers(self, environ=None): """Get a list of headers.""" - return [('Content-Type', 'text/html')] + return [("Content-Type", "text/html")] def get_response(self, environ=None): """Get a response object. If one was passed to the exception @@ -179,30 +182,29 @@ class HTTPException(Exception): return response(environ, start_response) def __str__(self): - code = self.code if self.code is not None else '???' - return '%s %s: %s' % (code, self.name, self.description) + code = self.code if self.code is not None else "???" + return "%s %s: %s" % (code, self.name, self.description) def __repr__(self): - code = self.code if self.code is not None else '???' + code = self.code if self.code is not None else "???" return "<%s '%s: %s'>" % (self.__class__.__name__, code, self.name) class BadRequest(HTTPException): - """*400* `Bad Request` Raise if the browser sends something to the application the application or server cannot handle. """ + code = 400 description = ( - 'The browser (or proxy) sent a request that this server could ' - 'not understand.' + "The browser (or proxy) sent a request that this server could " + "not understand." ) class ClientDisconnected(BadRequest): - """Internal exception that is raised if Werkzeug detects a disconnected client. Since the client is already gone at that point attempting to send the error message to the client might not work and might ultimately @@ -218,7 +220,6 @@ class ClientDisconnected(BadRequest): class SecurityError(BadRequest): - """Raised if something triggers a security error. This is otherwise exactly like a bad request error. @@ -227,7 +228,6 @@ class SecurityError(BadRequest): class BadHost(BadRequest): - """Raised if the submitted host is badly formatted. .. versionadded:: 0.11.2 @@ -235,7 +235,6 @@ class BadHost(BadRequest): class Unauthorized(HTTPException): - """*401* ``Unauthorized`` Raise if the user is not authorized to access a resource. @@ -252,12 +251,13 @@ class Unauthorized(HTTPException): :param description: Override the default message used for the body of the response. """ + code = 401 description = ( - 'The server could not verify that you are authorized to access' - ' the URL requested. You either supplied the wrong credentials' + "The server could not verify that you are authorized to access" + " the URL requested. You either supplied the wrong credentials" " (e.g. a bad password), or your browser doesn't understand" - ' how to supply the credentials required.' + " how to supply the credentials required." ) def __init__(self, www_authenticate=None, description=None): @@ -269,43 +269,41 @@ class Unauthorized(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.www_authenticate: - headers.append(( - 'WWW-Authenticate', - ', '.join([str(x) for x in self.www_authenticate]) - )) + headers.append( + ("WWW-Authenticate", ", ".join([str(x) for x in self.www_authenticate])) + ) return headers class Forbidden(HTTPException): - """*403* `Forbidden` Raise if the user doesn't have the permission for the requested resource but was authenticated. """ + code = 403 description = ( - 'You don\'t have the permission to access the requested resource. ' - 'It is either read-protected or not readable by the server.' + "You don't have the permission to access the requested" + " resource. It is either read-protected or not readable by the" + " server." ) class NotFound(HTTPException): - """*404* `Not Found` Raise if a resource does not exist and never existed. """ + code = 404 description = ( - 'The requested URL was not found on the server. ' - 'If you entered the URL manually please check your spelling and ' - 'try again.' + "The requested URL was not found on the server. If you entered" + " the URL manually please check your spelling and try again." ) class MethodNotAllowed(HTTPException): - """*405* `Method Not Allowed` Raise if the server used a method the resource does not handle. For @@ -315,8 +313,9 @@ class MethodNotAllowed(HTTPException): Strictly speaking the response would be invalid if you don't provide valid methods in the header which you can do with that list. """ + code = 405 - description = 'The method is not allowed for the requested URL.' + description = "The method is not allowed for the requested URL." def __init__(self, valid_methods=None, description=None): """Takes an optional list of valid http methods @@ -327,42 +326,41 @@ class MethodNotAllowed(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.valid_methods: - headers.append(('Allow', ', '.join(self.valid_methods))) + headers.append(("Allow", ", ".join(self.valid_methods))) return headers class NotAcceptable(HTTPException): - """*406* `Not Acceptable` Raise if the server can't return any content conforming to the `Accept` headers of the client. """ + code = 406 description = ( - 'The resource identified by the request is only capable of ' - 'generating response entities which have content characteristics ' - 'not acceptable according to the accept headers sent in the ' - 'request.' + "The resource identified by the request is only capable of" + " generating response entities which have content" + " characteristics not acceptable according to the accept" + " headers sent in the request." ) class RequestTimeout(HTTPException): - """*408* `Request Timeout` Raise to signalize a timeout. """ + code = 408 description = ( - 'The server closed the network connection because the browser ' - 'didn\'t finish the request within the specified time.' + "The server closed the network connection because the browser" + " didn't finish the request within the specified time." ) class Conflict(HTTPException): - """*409* `Conflict` Raise to signal that a request cannot be completed because it conflicts @@ -370,107 +368,103 @@ class Conflict(HTTPException): .. versionadded:: 0.7 """ + code = 409 description = ( - 'A conflict happened while processing the request. The resource ' - 'might have been modified while the request was being processed.' + "A conflict happened while processing the request. The" + " resource might have been modified while the request was being" + " processed." ) class Gone(HTTPException): - """*410* `Gone` Raise if a resource existed previously and went away without new location. """ + code = 410 description = ( - 'The requested URL is no longer available on this server and there ' - 'is no forwarding address. If you followed a link from a foreign ' - 'page, please contact the author of this page.' + "The requested URL is no longer available on this server and" + " there is no forwarding address. If you followed a link from a" + " foreign page, please contact the author of this page." ) class LengthRequired(HTTPException): - """*411* `Length Required` Raise if the browser submitted data but no ``Content-Length`` header which is required for the kind of processing the server does. """ + code = 411 description = ( - 'A request with this method requires a valid <code>Content-' - 'Length</code> header.' + "A request with this method requires a valid <code>Content-" + "Length</code> header." ) class PreconditionFailed(HTTPException): - """*412* `Precondition Failed` Status code used in combination with ``If-Match``, ``If-None-Match``, or ``If-Unmodified-Since``. """ + code = 412 description = ( - 'The precondition on the request for the URL failed positive ' - 'evaluation.' + "The precondition on the request for the URL failed positive evaluation." ) class RequestEntityTooLarge(HTTPException): - """*413* `Request Entity Too Large` The status code one should return if the data submitted exceeded a given limit. """ + code = 413 - description = ( - 'The data value transmitted exceeds the capacity limit.' - ) + description = "The data value transmitted exceeds the capacity limit." class RequestURITooLarge(HTTPException): - """*414* `Request URI Too Large` Like *413* but for too long URLs. """ + code = 414 description = ( - 'The length of the requested URL exceeds the capacity limit ' - 'for this server. The request cannot be processed.' + "The length of the requested URL exceeds the capacity limit for" + " this server. The request cannot be processed." ) class UnsupportedMediaType(HTTPException): - """*415* `Unsupported Media Type` The status code returned if the server is unable to handle the media type the client transmitted. """ + code = 415 description = ( - 'The server does not support the media type transmitted in ' - 'the request.' + "The server does not support the media type transmitted in the request." ) class RequestedRangeNotSatisfiable(HTTPException): - """*416* `Requested Range Not Satisfiable` The client asked for an invalid part of the file. .. versionadded:: 0.7 """ + code = 416 - description = ( - 'The server cannot provide the requested range.' - ) + description = "The server cannot provide the requested range." def __init__(self, length=None, units="bytes", description=None): """Takes an optional `Content-Range` header value based on ``length`` @@ -483,27 +477,23 @@ class RequestedRangeNotSatisfiable(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.length is not None: - headers.append( - ('Content-Range', '%s */%d' % (self.units, self.length))) + headers.append(("Content-Range", "%s */%d" % (self.units, self.length))) return headers class ExpectationFailed(HTTPException): - """*417* `Expectation Failed` The server cannot meet the requirements of the Expect request-header. .. versionadded:: 0.7 """ + code = 417 - description = ( - 'The server could not meet the requirements of the Expect header' - ) + description = "The server could not meet the requirements of the Expect header" class ImATeapot(HTTPException): - """*418* `I'm a teapot` The server should return this if it is a teapot and someone attempted @@ -511,54 +501,51 @@ class ImATeapot(HTTPException): .. versionadded:: 0.7 """ + code = 418 - description = ( - 'This server is a teapot, not a coffee machine' - ) + description = "This server is a teapot, not a coffee machine" class UnprocessableEntity(HTTPException): - """*422* `Unprocessable Entity` Used if the request is well formed, but the instructions are otherwise incorrect. """ + code = 422 description = ( - 'The request was well-formed but was unable to be followed ' - 'due to semantic errors.' + "The request was well-formed but was unable to be followed due" + " to semantic errors." ) class Locked(HTTPException): - """*423* `Locked` Used if the resource that is being accessed is locked. """ + code = 423 - description = ( - 'The resource that is being accessed is locked.' - ) + description = "The resource that is being accessed is locked." class FailedDependency(HTTPException): - """*424* `Failed Dependency` Used if the method could not be performed on the resource because the requested action depended on another action and that action failed. """ + code = 424 description = ( - 'The method could not be performed on the resource because the requested action ' - 'depended on another action and that action failed.' + "The method could not be performed on the resource because the" + " requested action depended on another action and that action" + " failed." ) class PreconditionRequired(HTTPException): - """*428* `Precondition Required` The server requires this request to be conditional, typically to prevent @@ -569,15 +556,15 @@ class PreconditionRequired(HTTPException): server ensures that each client has at least seen the previous revision of the resource. """ + code = 428 description = ( - 'This request is required to be conditional; try using "If-Match" ' - 'or "If-Unmodified-Since".' + "This request is required to be conditional; try using" + ' "If-Match" or "If-Unmodified-Since".' ) class TooManyRequests(HTTPException): - """*429* `Too Many Requests` The server is limiting the rate at which this user receives responses, and @@ -586,129 +573,117 @@ class TooManyRequests(HTTPException): "Retry-After" header to indicate how long the user should wait before retrying. """ + code = 429 - description = ( - 'This user has exceeded an allotted request count. Try again later.' - ) + description = "This user has exceeded an allotted request count. Try again later." class RequestHeaderFieldsTooLarge(HTTPException): - """*431* `Request Header Fields Too Large` The server refuses to process the request because the header fields are too large. One or more individual fields may be too large, or the set of all headers is too large. """ + code = 431 - description = ( - 'One or more header fields exceeds the maximum size.' - ) + description = "One or more header fields exceeds the maximum size." class UnavailableForLegalReasons(HTTPException): - """*451* `Unavailable For Legal Reasons` This status code indicates that the server is denying access to the resource as a consequence of a legal demand. """ + code = 451 - description = ( - 'Unavailable for legal reasons.' - ) + description = "Unavailable for legal reasons." class InternalServerError(HTTPException): - """*500* `Internal Server Error` Raise if an internal server error occurred. This is a good fallback if an unknown error occurred in the dispatcher. """ + code = 500 description = ( - 'The server encountered an internal error and was unable to ' - 'complete your request. Either the server is overloaded or there ' - 'is an error in the application.' + "The server encountered an internal error and was unable to" + " complete your request. Either the server is overloaded or" + " there is an error in the application." ) class NotImplemented(HTTPException): - """*501* `Not Implemented` Raise if the application does not support the action requested by the browser. """ + code = 501 - description = ( - 'The server does not support the action requested by the ' - 'browser.' - ) + description = "The server does not support the action requested by the browser." class BadGateway(HTTPException): - """*502* `Bad Gateway` If you do proxying in your application you should return this status code if you received an invalid response from the upstream server it accessed in attempting to fulfill the request. """ + code = 502 description = ( - 'The proxy server received an invalid response from an upstream ' - 'server.' + "The proxy server received an invalid response from an upstream server." ) class ServiceUnavailable(HTTPException): - """*503* `Service Unavailable` Status code you should return if a service is temporarily unavailable. """ + code = 503 description = ( - 'The server is temporarily unable to service your request due to ' - 'maintenance downtime or capacity problems. Please try again ' - 'later.' + "The server is temporarily unable to service your request due" + " to maintenance downtime or capacity problems. Please try" + " again later." ) class GatewayTimeout(HTTPException): - """*504* `Gateway Timeout` Status code you should return if a connection to an upstream server times out. """ + code = 504 - description = ( - 'The connection to an upstream server timed out.' - ) + description = "The connection to an upstream server timed out." class HTTPVersionNotSupported(HTTPException): - """*505* `HTTP Version Not Supported` The server does not support the HTTP protocol version used in the request. """ + code = 505 description = ( - 'The server does not support the HTTP protocol version used in the ' - 'request.' + "The server does not support the HTTP protocol version used in the request." ) default_exceptions = {} -__all__ = ['HTTPException'] +__all__ = ["HTTPException"] def _find_exceptions(): - for name, obj in iteritems(globals()): + for _name, obj in iteritems(globals()): try: is_http_exception = issubclass(obj, HTTPException) except TypeError: @@ -720,14 +695,14 @@ def _find_exceptions(): if old_obj is not None and issubclass(obj, old_obj): continue default_exceptions[obj.code] = obj + + _find_exceptions() del _find_exceptions class Aborter(object): - - """ - When passed a dict of code -> exception items it can be used as + """When passed a dict of code -> exception items it can be used as callable that raises exceptions. If the first argument to the callable is an integer it will be looked up in the mapping, if it's a WSGI application it will be raised in a proxy exception. @@ -746,13 +721,12 @@ class Aborter(object): if not args and not kwargs and not isinstance(code, integer_types): raise HTTPException(response=code) if code not in self.mapping: - raise LookupError('no exception for %r' % code) + raise LookupError("no exception for %r" % code) raise self.mapping[code](*args, **kwargs) def abort(status, *args, **kwargs): - ''' - Raises an :py:exc:`HTTPException` for the given status code or WSGI + """Raises an :py:exc:`HTTPException` for the given status code or WSGI application:: abort(404) # 404 Not Found @@ -766,9 +740,10 @@ def abort(status, *args, **kwargs): abort(404) abort(Response('Hello World')) - ''' + """ return _aborter(status, *args, **kwargs) + _aborter = Aborter() @@ -777,5 +752,5 @@ _aborter = Aborter() BadRequestKeyError = BadRequest.wrap(KeyError) # imported here because of circular dependencies of werkzeug.utils -from werkzeug.utils import escape -from werkzeug.http import HTTP_STATUS_CODES +from .http import HTTP_STATUS_CODES +from .utils import escape diff --git a/src/werkzeug/filesystem.py b/src/werkzeug/filesystem.py index 927f37ba..d016caea 100644 --- a/src/werkzeug/filesystem.py +++ b/src/werkzeug/filesystem.py @@ -8,41 +8,39 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - import codecs import sys import warnings # We do not trust traditional unixes. -has_likely_buggy_unicode_filesystem = \ - sys.platform.startswith('linux') or 'bsd' in sys.platform +has_likely_buggy_unicode_filesystem = ( + sys.platform.startswith("linux") or "bsd" in sys.platform +) def _is_ascii_encoding(encoding): - """ - Given an encoding this figures out if the encoding is actually ASCII (which + """Given an encoding this figures out if the encoding is actually ASCII (which is something we don't actually want in most cases). This is necessary because ASCII comes under many names such as ANSI_X3.4-1968. """ if encoding is None: return False try: - return codecs.lookup(encoding).name == 'ascii' + return codecs.lookup(encoding).name == "ascii" except LookupError: return False class BrokenFilesystemWarning(RuntimeWarning, UnicodeWarning): - '''The warning used by Werkzeug to signal a broken filesystem. Will only be - used once per runtime.''' + """The warning used by Werkzeug to signal a broken filesystem. Will only be + used once per runtime.""" _warned_about_filesystem_encoding = False def get_filesystem_encoding(): - """ - Returns the filesystem encoding that should be used. Note that this is + """Returns the filesystem encoding that should be used. Note that this is different from the Python understanding of the filesystem encoding which might be deeply flawed. Do not use this value against Python's unicode APIs because it might be different. See :ref:`filesystem-encoding` for the exact @@ -54,13 +52,13 @@ def get_filesystem_encoding(): """ global _warned_about_filesystem_encoding rv = sys.getfilesystemencoding() - if has_likely_buggy_unicode_filesystem and not rv \ - or _is_ascii_encoding(rv): + if has_likely_buggy_unicode_filesystem and not rv or _is_ascii_encoding(rv): if not _warned_about_filesystem_encoding: warnings.warn( - 'Detected a misconfigured UNIX filesystem: Will use UTF-8 as ' - 'filesystem encoding instead of {0!r}'.format(rv), - BrokenFilesystemWarning) + "Detected a misconfigured UNIX filesystem: Will use" + " UTF-8 as filesystem encoding instead of {0!r}".format(rv), + BrokenFilesystemWarning, + ) _warned_about_filesystem_encoding = True - return 'utf-8' + return "utf-8" return rv diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 370c0751..e626935d 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -9,8 +9,24 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import codecs +import re +from functools import update_wrapper +from itertools import chain +from itertools import repeat +from itertools import tee + +from ._compat import BytesIO +from ._compat import text_type +from ._compat import to_native +from .datastructures import FileStorage +from .datastructures import Headers +from .datastructures import MultiDict +from .http import parse_options_header +from .urls import url_decode_stream +from .wsgi import get_content_length +from .wsgi import get_input_stream +from .wsgi import make_line_iter # there are some platforms where SpooledTemporaryFile is not available. # In that case we need to provide a fallback. @@ -18,45 +34,43 @@ try: from tempfile import SpooledTemporaryFile except ImportError: from tempfile import TemporaryFile - SpooledTemporaryFile = None - -from itertools import chain, repeat, tee -from functools import update_wrapper -from werkzeug._compat import to_native, text_type, BytesIO -from werkzeug.urls import url_decode_stream -from werkzeug.wsgi import make_line_iter, \ - get_input_stream, get_content_length -from werkzeug.datastructures import Headers, FileStorage, MultiDict -from werkzeug.http import parse_options_header + SpooledTemporaryFile = None #: an iterator that yields empty strings -_empty_string_iter = repeat('') +_empty_string_iter = repeat("") #: a regular expression for multipart boundaries -_multipart_boundary_re = re.compile('^[ -~]{0,200}[!-~]$') +_multipart_boundary_re = re.compile("^[ -~]{0,200}[!-~]$") #: supported http encodings that are also available in python we support #: for multipart messages. -_supported_multipart_encodings = frozenset(['base64', 'quoted-printable']) +_supported_multipart_encodings = frozenset(["base64", "quoted-printable"]) -def default_stream_factory(total_content_length, filename, content_type, - content_length=None): +def default_stream_factory( + total_content_length, filename, content_type, content_length=None +): """The stream factory that is used per default.""" max_size = 1024 * 500 if SpooledTemporaryFile is not None: - return SpooledTemporaryFile(max_size=max_size, mode='wb+') + return SpooledTemporaryFile(max_size=max_size, mode="wb+") if total_content_length is None or total_content_length > max_size: - return TemporaryFile('wb+') + return TemporaryFile("wb+") return BytesIO() -def parse_form_data(environ, stream_factory=None, charset='utf-8', - errors='replace', max_form_memory_size=None, - max_content_length=None, cls=None, - silent=True): +def parse_form_data( + environ, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + max_content_length=None, + cls=None, + silent=True, +): """Parse the form data in the environ and return it as tuple in the form ``(stream, form, files)``. You should only call this method if the transport method is `POST`, `PUT`, or `PATCH`. @@ -97,9 +111,15 @@ def parse_form_data(environ, stream_factory=None, charset='utf-8', :param silent: If set to False parsing errors will not be caught. :return: A tuple in the form ``(stream, form, files)``. """ - return FormDataParser(stream_factory, charset, errors, - max_form_memory_size, max_content_length, - cls, silent).parse_from_environ(environ) + return FormDataParser( + stream_factory, + charset, + errors, + max_form_memory_size, + max_content_length, + cls, + silent, + ).parse_from_environ(environ) def exhaust_stream(f): @@ -109,7 +129,7 @@ def exhaust_stream(f): try: return f(self, stream, *args, **kwargs) finally: - exhaust = getattr(stream, 'exhaust', None) + exhaust = getattr(stream, "exhaust", None) if exhaust is not None: exhaust() else: @@ -117,11 +137,11 @@ def exhaust_stream(f): chunk = stream.read(1024 * 64) if not chunk: break + return update_wrapper(wrapper, f) class FormDataParser(object): - """This class implements parsing of form data for Werkzeug. By itself it can parse multipart and url encoded form data. It can be subclassed and extended but for most mimetypes it is a better idea to use the @@ -149,10 +169,16 @@ class FormDataParser(object): :param silent: If set to False parsing errors will not be caught. """ - def __init__(self, stream_factory=None, charset='utf-8', - errors='replace', max_form_memory_size=None, - max_content_length=None, cls=None, - silent=True): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + max_content_length=None, + cls=None, + silent=True, + ): if stream_factory is None: stream_factory = default_stream_factory self.stream_factory = stream_factory @@ -174,11 +200,10 @@ class FormDataParser(object): :param environ: the WSGI environment to be used for parsing. :return: A tuple in the form ``(stream, form, files)``. """ - content_type = environ.get('CONTENT_TYPE', '') + content_type = environ.get("CONTENT_TYPE", "") content_length = get_content_length(environ) mimetype, options = parse_options_header(content_type) - return self.parse(get_input_stream(environ), mimetype, - content_length, options) + return self.parse(get_input_stream(environ), mimetype, content_length, options) def parse(self, stream, mimetype, content_length, options=None): """Parses the information from the given stream, mimetype, @@ -191,9 +216,11 @@ class FormDataParser(object): the multipart boundary for instance) :return: A tuple in the form ``(stream, form, files)``. """ - if self.max_content_length is not None and \ - content_length is not None and \ - content_length > self.max_content_length: + if ( + self.max_content_length is not None + and content_length is not None + and content_length > self.max_content_length + ): raise exceptions.RequestEntityTooLarge() if options is None: options = {} @@ -201,8 +228,7 @@ class FormDataParser(object): parse_func = self.get_parse_func(mimetype, options) if parse_func is not None: try: - return parse_func(self, stream, mimetype, - content_length, options) + return parse_func(self, stream, mimetype, content_length, options) except ValueError: if not self.silent: raise @@ -211,32 +237,37 @@ class FormDataParser(object): @exhaust_stream def _parse_multipart(self, stream, mimetype, content_length, options): - parser = MultiPartParser(self.stream_factory, self.charset, self.errors, - max_form_memory_size=self.max_form_memory_size, - cls=self.cls) - boundary = options.get('boundary') + parser = MultiPartParser( + self.stream_factory, + self.charset, + self.errors, + max_form_memory_size=self.max_form_memory_size, + cls=self.cls, + ) + boundary = options.get("boundary") if boundary is None: - raise ValueError('Missing boundary') + raise ValueError("Missing boundary") if isinstance(boundary, text_type): - boundary = boundary.encode('ascii') + boundary = boundary.encode("ascii") form, files = parser.parse(stream, boundary, content_length) return stream, form, files @exhaust_stream def _parse_urlencoded(self, stream, mimetype, content_length, options): - if self.max_form_memory_size is not None and \ - content_length is not None and \ - content_length > self.max_form_memory_size: + if ( + self.max_form_memory_size is not None + and content_length is not None + and content_length > self.max_form_memory_size + ): raise exceptions.RequestEntityTooLarge() - form = url_decode_stream(stream, self.charset, - errors=self.errors, cls=self.cls) + form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls) return stream, form, self.cls() #: mapping of mimetypes to parsing functions parse_functions = { - 'multipart/form-data': _parse_multipart, - 'application/x-www-form-urlencoded': _parse_urlencoded, - 'application/x-url-encoded': _parse_urlencoded + "multipart/form-data": _parse_multipart, + "application/x-www-form-urlencoded": _parse_urlencoded, + "application/x-url-encoded": _parse_urlencoded, } @@ -249,9 +280,9 @@ def _line_parse(line): """Removes line ending characters and returns a tuple (`stripped_line`, `is_terminated`). """ - if line[-2:] in ['\r\n', b'\r\n']: + if line[-2:] in ["\r\n", b"\r\n"]: return line[:-2], True - elif line[-1:] in ['\r', '\n', b'\r', b'\n']: + elif line[-1:] in ["\r", "\n", b"\r", b"\n"]: return line[:-1], True return line, False @@ -270,14 +301,14 @@ def parse_multipart_headers(iterable): line = to_native(line) line, line_terminated = _line_parse(line) if not line_terminated: - raise ValueError('unexpected end of line in multipart header') + raise ValueError("unexpected end of line in multipart header") if not line: break - elif line[0] in ' \t' and result: + elif line[0] in " \t" and result: key, value = result[-1] - result[-1] = (key, value + '\n ' + line[1:]) + result[-1] = (key, value + "\n " + line[1:]) else: - parts = line.split(':', 1) + parts = line.split(":", 1) if len(parts) == 2: result.append((parts[0].strip(), parts[1].strip())) @@ -286,28 +317,36 @@ def parse_multipart_headers(iterable): return Headers(result) -_begin_form = 'begin_form' -_begin_file = 'begin_file' -_cont = 'cont' -_end = 'end' +_begin_form = "begin_form" +_begin_file = "begin_file" +_cont = "cont" +_end = "end" class MultiPartParser(object): - - def __init__(self, stream_factory=None, charset='utf-8', errors='replace', - max_form_memory_size=None, cls=None, buffer_size=64 * 1024): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + cls=None, + buffer_size=64 * 1024, + ): self.charset = charset self.errors = errors self.max_form_memory_size = max_form_memory_size - self.stream_factory = default_stream_factory if stream_factory is None else stream_factory + self.stream_factory = ( + default_stream_factory if stream_factory is None else stream_factory + ) self.cls = MultiDict if cls is None else cls # make sure the buffer size is divisible by four so that we can base64 # decode chunk by chunk - assert buffer_size % 4 == 0, 'buffer size has to be divisible by 4' + assert buffer_size % 4 == 0, "buffer size has to be divisible by 4" # also the buffer size has to be at least 1024 bytes long or long headers # will freak out the system - assert buffer_size >= 1024, 'buffer size has to be at least 1KB' + assert buffer_size >= 1024, "buffer size has to be at least 1KB" self.buffer_size = buffer_size @@ -316,8 +355,8 @@ class MultiPartParser(object): uploaded. This function strips the full path if it thinks the filename is Windows-like absolute. """ - if filename[1:3] == ':\\' or filename[:2] == '\\\\': - return filename.split('\\')[-1] + if filename[1:3] == ":\\" or filename[:2] == "\\\\": + return filename.split("\\")[-1] return filename def _find_terminator(self, iterator): @@ -331,36 +370,39 @@ class MultiPartParser(object): line = line.strip() if line: return line - return b'' + return b"" def fail(self, message): raise ValueError(message) def get_part_encoding(self, headers): - transfer_encoding = headers.get('content-transfer-encoding') - if transfer_encoding is not None and \ - transfer_encoding in _supported_multipart_encodings: + transfer_encoding = headers.get("content-transfer-encoding") + if ( + transfer_encoding is not None + and transfer_encoding in _supported_multipart_encodings + ): return transfer_encoding def get_part_charset(self, headers): # Figure out input charset for current part - content_type = headers.get('content-type') + content_type = headers.get("content-type") if content_type: mimetype, ct_params = parse_options_header(content_type) - return ct_params.get('charset', self.charset) + return ct_params.get("charset", self.charset) return self.charset def start_file_streaming(self, filename, headers, total_content_length): if isinstance(filename, bytes): filename = filename.decode(self.charset, self.errors) filename = self._fix_ie_filename(filename) - content_type = headers.get('content-type') + content_type = headers.get("content-type") try: - content_length = int(headers['content-length']) + content_length = int(headers["content-length"]) except (KeyError, ValueError): content_length = 0 - container = self.stream_factory(total_content_length, content_type, - filename, content_length) + container = self.stream_factory( + total_content_length, content_type, filename, content_length + ) return filename, container def in_memory_threshold_reached(self, bytes): @@ -368,15 +410,15 @@ class MultiPartParser(object): def validate_boundary(self, boundary): if not boundary: - self.fail('Missing boundary') + self.fail("Missing boundary") if not is_valid_multipart_boundary(boundary): - self.fail('Invalid boundary: %s' % boundary) + self.fail("Invalid boundary: %s" % boundary) if len(boundary) > self.buffer_size: # pragma: no cover # this should never happen because we check for a minimum size # of 1024 and boundaries may not be longer than 200. The only # situation when this happens is for non debug builds where # the assert is skipped. - self.fail('Boundary longer than buffer size') + self.fail("Boundary longer than buffer size") def parse_lines(self, file, boundary, content_length, cap_at_buffer=True): """Generate parts of @@ -389,31 +431,36 @@ class MultiPartParser(object): parts = ( begin_form cont* end | begin_file cont* end )* """ - next_part = b'--' + boundary - last_part = next_part + b'--' - - iterator = chain(make_line_iter(file, limit=content_length, - buffer_size=self.buffer_size, - cap_at_buffer=cap_at_buffer), - _empty_string_iter) + next_part = b"--" + boundary + last_part = next_part + b"--" + + iterator = chain( + make_line_iter( + file, + limit=content_length, + buffer_size=self.buffer_size, + cap_at_buffer=cap_at_buffer, + ), + _empty_string_iter, + ) terminator = self._find_terminator(iterator) if terminator == last_part: return elif terminator != next_part: - self.fail('Expected boundary at start of multipart data') + self.fail("Expected boundary at start of multipart data") while terminator != last_part: headers = parse_multipart_headers(iterator) - disposition = headers.get('content-disposition') + disposition = headers.get("content-disposition") if disposition is None: - self.fail('Missing Content-Disposition header') + self.fail("Missing Content-Disposition header") disposition, extra = parse_options_header(disposition) transfer_encoding = self.get_part_encoding(headers) - name = extra.get('name') - filename = extra.get('filename') + name = extra.get("name") + filename = extra.get("filename") # if no content type is given we stream into memory. A list is # used as a temporary container. @@ -425,29 +472,29 @@ class MultiPartParser(object): else: yield _begin_file, (headers, name, filename) - buf = b'' + buf = b"" for line in iterator: if not line: - self.fail('unexpected end of stream') + self.fail("unexpected end of stream") - if line[:2] == b'--': + if line[:2] == b"--": terminator = line.rstrip() if terminator in (next_part, last_part): break if transfer_encoding is not None: - if transfer_encoding == 'base64': - transfer_encoding = 'base64_codec' + if transfer_encoding == "base64": + transfer_encoding = "base64_codec" try: line = codecs.decode(line, transfer_encoding) except Exception: - self.fail('could not decode transfer encoded chunk') + self.fail("could not decode transfer encoded chunk") # we have something in the buffer from the last iteration. # this is usually a newline delimiter. if buf: yield _cont, buf - buf = b'' + buf = b"" # If the line ends with windows CRLF we write everything except # the last two bytes. In all other cases however we write @@ -458,8 +505,8 @@ class MultiPartParser(object): # truncate the stream. However we do have to make sure that # if something else than a newline is in there we write it # out. - if line[-2:] == b'\r\n': - buf = b'\r\n' + if line[-2:] == b"\r\n": + buf = b"\r\n" cutoff = -2 else: buf = line[-1:] @@ -467,12 +514,12 @@ class MultiPartParser(object): yield _cont, line[:cutoff] else: # pragma: no cover - raise ValueError('unexpected end of part') + raise ValueError("unexpected end of part") # if we have a leftover in the buffer that is not a newline # character we have to flush it, otherwise we will chop of # certain values. - if buf not in (b'', b'\r', b'\n', b'\r\n'): + if buf not in (b"", b"\r", b"\n", b"\r\n"): yield _cont, buf yield _end, None @@ -489,7 +536,8 @@ class MultiPartParser(object): is_file = True guard_memory = False filename, container = self.start_file_streaming( - filename, headers, content_length) + filename, headers, content_length + ) _write = container.write elif ellt == _begin_form: @@ -512,21 +560,24 @@ class MultiPartParser(object): elif ellt == _end: if is_file: container.seek(0) - yield ('file', - (name, FileStorage(container, filename, name, - headers=headers))) + yield ( + "file", + (name, FileStorage(container, filename, name, headers=headers)), + ) else: part_charset = self.get_part_charset(headers) - yield ('form', - (name, b''.join(container).decode( - part_charset, self.errors))) + yield ( + "form", + (name, b"".join(container).decode(part_charset, self.errors)), + ) def parse(self, file, boundary, content_length): formstream, filestream = tee( - self.parse_parts(file, boundary, content_length), 2) - form = (p[1] for p in formstream if p[0] == 'form') - files = (p[1] for p in filestream if p[0] == 'file') + self.parse_parts(file, boundary, content_length), 2 + ) + form = (p[1] for p in formstream if p[0] == "form") + files = (p[1] for p in filestream if p[0] == "file") return self.cls(form), self.cls(files) -from werkzeug import exceptions +from . import exceptions diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 9e18ea97..af320075 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -16,54 +16,68 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import base64 import re import warnings -from time import time, gmtime +from datetime import datetime +from datetime import timedelta +from hashlib import md5 +from time import gmtime +from time import time + +from ._compat import integer_types +from ._compat import iteritems +from ._compat import PY2 +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._internal import _cookie_parse_impl +from ._internal import _cookie_quote +from ._internal import _make_cookie_domain + try: from email.utils import parsedate_tz -except ImportError: # pragma: no cover +except ImportError: from email.Utils import parsedate_tz + try: from urllib.request import parse_http_list as _parse_list_header from urllib.parse import unquote_to_bytes as _unquote -except ImportError: # pragma: no cover - from urllib2 import parse_http_list as _parse_list_header, \ - unquote as _unquote -from datetime import datetime, timedelta -from hashlib import md5 -import base64 +except ImportError: + from urllib2 import parse_http_list as _parse_list_header + from urllib2 import unquote as _unquote -from werkzeug._internal import _cookie_quote, _make_cookie_domain, \ - _cookie_parse_impl -from werkzeug._compat import to_unicode, iteritems, text_type, \ - string_types, try_coerce_native, to_bytes, PY2, \ - integer_types - - -_cookie_charset = 'latin1' -_basic_auth_charset = 'utf-8' +_cookie_charset = "latin1" +_basic_auth_charset = "utf-8" # for explanation of "media-range", etc. see Sections 5.3.{1,2} of RFC 7231 _accept_re = re.compile( - r'''( # media-range capturing-parenthesis - [^\s;,]+ # type/subtype - (?:[ \t]*;[ \t]* # ";" - (?: # parameter non-capturing-parenthesis - [^\s;,q][^\s;,]* # token that doesn't start with "q" - | # or - q[^\s;,=][^\s;,]* # token that is more than just "q" - ) - )* # zero or more parameters - ) # end of media-range - (?:[ \t]*;[ \t]*q= # weight is a "q" parameter - (\d*(?:\.\d+)?) # qvalue capturing-parentheses - [^,]* # "extension" accept params: who cares? - )? # accept params are optional - ''', re.VERBOSE) -_token_chars = frozenset("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" - '^_`abcdefghijklmnopqrstuvwxyz|~') + r""" + ( # media-range capturing-parenthesis + [^\s;,]+ # type/subtype + (?:[ \t]*;[ \t]* # ";" + (?: # parameter non-capturing-parenthesis + [^\s;,q][^\s;,]* # token that doesn't start with "q" + | # or + q[^\s;,=][^\s;,]* # token that is more than just "q" + ) + )* # zero or more parameters + ) # end of media-range + (?:[ \t]*;[ \t]*q= # weight is a "q" parameter + (\d*(?:\.\d+)?) # qvalue capturing-parentheses + [^,]* # "extension" accept params: who cares? + )? # accept params are optional + """, + re.VERBOSE, +) +_token_chars = frozenset( + "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~" +) _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') -_unsafe_header_chars = set('()<>@,;:\"/[]?={} \t') -_option_header_piece_re = re.compile(r''' +_unsafe_header_chars = set('()<>@,;:"/[]?={} \t') +_option_header_piece_re = re.compile( + r""" ;\s*,?\s* # newlines were replaced with commas (?P<key> "[^"\\]*(?:\\.[^"\\]*)*" # quoted string @@ -89,100 +103,116 @@ _option_header_piece_re = re.compile(r''' )? )? \s* -''', flags=re.VERBOSE) -_option_header_start_mime_type = re.compile(r',\s*([^;,\s]+)([;,]\s*.+)?') - -_entity_headers = frozenset([ - 'allow', 'content-encoding', 'content-language', 'content-length', - 'content-location', 'content-md5', 'content-range', 'content-type', - 'expires', 'last-modified' -]) -_hop_by_hop_headers = frozenset([ - 'connection', 'keep-alive', 'proxy-authenticate', - 'proxy-authorization', 'te', 'trailer', 'transfer-encoding', - 'upgrade' -]) + """, + flags=re.VERBOSE, +) +_option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?") + +_entity_headers = frozenset( + [ + "allow", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-type", + "expires", + "last-modified", + ] +) +_hop_by_hop_headers = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", + ] +) HTTP_STATUS_CODES = { - 100: 'Continue', - 101: 'Switching Protocols', - 102: 'Processing', - 200: 'OK', - 201: 'Created', - 202: 'Accepted', - 203: 'Non Authoritative Information', - 204: 'No Content', - 205: 'Reset Content', - 206: 'Partial Content', - 207: 'Multi Status', - 226: 'IM Used', # see RFC 3229 - 300: 'Multiple Choices', - 301: 'Moved Permanently', - 302: 'Found', - 303: 'See Other', - 304: 'Not Modified', - 305: 'Use Proxy', - 307: 'Temporary Redirect', - 308: 'Permanent Redirect', - 400: 'Bad Request', - 401: 'Unauthorized', - 402: 'Payment Required', # unused - 403: 'Forbidden', - 404: 'Not Found', - 405: 'Method Not Allowed', - 406: 'Not Acceptable', - 407: 'Proxy Authentication Required', - 408: 'Request Timeout', - 409: 'Conflict', - 410: 'Gone', - 411: 'Length Required', - 412: 'Precondition Failed', - 413: 'Request Entity Too Large', - 414: 'Request URI Too Long', - 415: 'Unsupported Media Type', - 416: 'Requested Range Not Satisfiable', - 417: 'Expectation Failed', - 418: 'I\'m a teapot', # see RFC 2324 - 421: 'Misdirected Request', # see RFC 7540 - 422: 'Unprocessable Entity', - 423: 'Locked', - 424: 'Failed Dependency', - 426: 'Upgrade Required', - 428: 'Precondition Required', # see RFC 6585 - 429: 'Too Many Requests', - 431: 'Request Header Fields Too Large', - 449: 'Retry With', # proprietary MS extension - 451: 'Unavailable For Legal Reasons', - 500: 'Internal Server Error', - 501: 'Not Implemented', - 502: 'Bad Gateway', - 503: 'Service Unavailable', - 504: 'Gateway Timeout', - 505: 'HTTP Version Not Supported', - 507: 'Insufficient Storage', - 510: 'Not Extended' + 100: "Continue", + 101: "Switching Protocols", + 102: "Processing", + 200: "OK", + 201: "Created", + 202: "Accepted", + 203: "Non Authoritative Information", + 204: "No Content", + 205: "Reset Content", + 206: "Partial Content", + 207: "Multi Status", + 226: "IM Used", # see RFC 3229 + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 305: "Use Proxy", + 307: "Temporary Redirect", + 308: "Permanent Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", # unused + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Request Entity Too Large", + 414: "Request URI Too Long", + 415: "Unsupported Media Type", + 416: "Requested Range Not Satisfiable", + 417: "Expectation Failed", + 418: "I'm a teapot", # see RFC 2324 + 421: "Misdirected Request", # see RFC 7540 + 422: "Unprocessable Entity", + 423: "Locked", + 424: "Failed Dependency", + 426: "Upgrade Required", + 428: "Precondition Required", # see RFC 6585 + 429: "Too Many Requests", + 431: "Request Header Fields Too Large", + 449: "Retry With", # proprietary MS extension + 451: "Unavailable For Legal Reasons", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", + 507: "Insufficient Storage", + 510: "Not Extended", } def wsgi_to_bytes(data): - """coerce wsgi unicode represented bytes to real ones - - """ + """coerce wsgi unicode represented bytes to real ones""" if isinstance(data, bytes): return data - return data.encode('latin1') # XXX: utf8 fallback? + return data.encode("latin1") # XXX: utf8 fallback? def bytes_to_wsgi(data): - assert isinstance(data, bytes), 'data must be bytes' + assert isinstance(data, bytes), "data must be bytes" if isinstance(data, str): return data else: - return data.decode('latin1') + return data.decode("latin1") -def quote_header_value(value, extra_chars='', allow_token=True): +def quote_header_value(value, extra_chars="", allow_token=True): """Quote a header value if necessary. .. versionadded:: 0.5 @@ -199,7 +229,7 @@ def quote_header_value(value, extra_chars='', allow_token=True): token_chars = _token_chars | set(extra_chars) if set(value).issubset(token_chars): return value - return '"%s"' % value.replace('\\', '\\\\').replace('"', '\\"') + return '"%s"' % value.replace("\\", "\\\\").replace('"', '\\"') def unquote_header_value(value, is_filename=False): @@ -223,8 +253,8 @@ def unquote_header_value(value, is_filename=False): # replace sequence below on a UNC path has the effect of turning # the leading double slash into a single slash and then # _fix_ie_filename() doesn't work correctly. See #458. - if not is_filename or value[:2] != '\\\\': - return value.replace('\\\\', '\\').replace('\\"', '"') + if not is_filename or value[:2] != "\\\\": + return value.replace("\\\\", "\\").replace('\\"', '"') return value @@ -241,8 +271,8 @@ def dump_options_header(header, options): if value is None: segments.append(key) else: - segments.append('%s=%s' % (key, quote_header_value(value))) - return '; '.join(segments) + segments.append("%s=%s" % (key, quote_header_value(value))) + return "; ".join(segments) def dump_header(iterable, allow_token=True): @@ -266,14 +296,12 @@ def dump_header(iterable, allow_token=True): if value is None: items.append(key) else: - items.append('%s=%s' % ( - key, - quote_header_value(value, allow_token=allow_token) - )) + items.append( + "%s=%s" % (key, quote_header_value(value, allow_token=allow_token)) + ) else: - items = [quote_header_value(x, allow_token=allow_token) - for x in iterable] - return ', '.join(items) + items = [quote_header_value(x, allow_token=allow_token) for x in iterable] + return ", ".join(items) def parse_list_header(value): @@ -337,10 +365,10 @@ def parse_dict_header(value, cls=dict): # XXX: validate value = bytes_to_wsgi(value) for item in _parse_list_header(value): - if '=' not in item: + if "=" not in item: result[item] = None continue - name, value = item.split('=', 1) + name, value = item.split("=", 1) if value[:1] == value[-1:] == '"': value = unquote_header_value(value[1:-1]) result[name] = value @@ -369,7 +397,7 @@ def parse_options_header(value, multiple=False): if multiple=True """ if not value: - return '', {} + return "", {} result = [] @@ -400,9 +428,7 @@ def parse_options_header(value, multiple=False): continued_encoding = encoding option = unquote_header_value(option) if option_value is not None: - option_value = unquote_header_value( - option_value, - option == 'filename') + option_value = unquote_header_value(option_value, option == "filename") if encoding is not None: option_value = _unquote(option_value).decode(encoding) if count: @@ -412,13 +438,13 @@ def parse_options_header(value, multiple=False): options[option] = options.get(option, "") + option_value else: options[option] = option_value - rest = rest[optmatch.end():] + rest = rest[optmatch.end() :] result.append(options) if multiple is False: return tuple(result) value = rest - return tuple(result) if result else ('', {}) + return tuple(result) if result else ("", {}) def parse_accept_header(value, cls=None): @@ -525,26 +551,27 @@ def parse_authorization_header(value): auth_type = auth_type.lower() except ValueError: return - if auth_type == b'basic': + if auth_type == b"basic": try: - username, password = base64.b64decode(auth_info).split(b':', 1) + username, password = base64.b64decode(auth_info).split(b":", 1) except Exception: return return Authorization( - 'basic', { - 'username': to_unicode(username, _basic_auth_charset), - 'password': to_unicode(password, _basic_auth_charset) - } + "basic", + { + "username": to_unicode(username, _basic_auth_charset), + "password": to_unicode(password, _basic_auth_charset), + }, ) - elif auth_type == b'digest': + elif auth_type == b"digest": auth_map = parse_dict_header(auth_info) - for key in 'username', 'realm', 'nonce', 'uri', 'response': + for key in "username", "realm", "nonce", "uri", "response": if key not in auth_map: return - if 'qop' in auth_map: - if not auth_map.get('nc') or not auth_map.get('cnonce'): + if "qop" in auth_map: + if not auth_map.get("nc") or not auth_map.get("cnonce"): return - return Authorization('digest', auth_map) + return Authorization("digest", auth_map) def parse_www_authenticate_header(value, on_update=None): @@ -564,8 +591,7 @@ def parse_www_authenticate_header(value, on_update=None): auth_type = auth_type.lower() except (ValueError, AttributeError): return WWWAuthenticate(value.strip().lower(), on_update=on_update) - return WWWAuthenticate(auth_type, parse_dict_header(auth_info), - on_update) + return WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) def parse_if_range_header(value): @@ -591,19 +617,19 @@ def parse_range_header(value, make_inclusive=True): .. versionadded:: 0.7 """ - if not value or '=' not in value: + if not value or "=" not in value: return None ranges = [] last_end = 0 - units, rng = value.split('=', 1) + units, rng = value.split("=", 1) units = units.strip().lower() - for item in rng.split(','): + for item in rng.split(","): item = item.strip() - if '-' not in item: + if "-" not in item: return None - if item.startswith('-'): + if item.startswith("-"): if last_end < 0: return None try: @@ -612,8 +638,8 @@ def parse_range_header(value, make_inclusive=True): return None end = None last_end = -1 - elif '-' in item: - begin, end = item.split('-', 1) + elif "-" in item: + begin, end = item.split("-", 1) begin = begin.strip() end = end.strip() if not begin.isdigit(): @@ -650,26 +676,26 @@ def parse_content_range_header(value, on_update=None): if value is None: return None try: - units, rangedef = (value or '').strip().split(None, 1) + units, rangedef = (value or "").strip().split(None, 1) except ValueError: return None - if '/' not in rangedef: + if "/" not in rangedef: return None - rng, length = rangedef.split('/', 1) - if length == '*': + rng, length = rangedef.split("/", 1) + if length == "*": length = None elif length.isdigit(): length = int(length) else: return None - if rng == '*': + if rng == "*": return ContentRange(units, None, None, length, on_update=on_update) - elif '-' not in rng: + elif "-" not in rng: return None - start, stop = rng.split('-', 1) + start, stop = rng.split("-", 1) try: start = int(start) stop = int(stop) + 1 @@ -687,10 +713,10 @@ def quote_etag(etag, weak=False): :param weak: set to `True` to tag it "weak". """ if '"' in etag: - raise ValueError('invalid etag') + raise ValueError("invalid etag") etag = '"%s"' % etag if weak: - etag = 'W/' + etag + etag = "W/" + etag return etag @@ -709,7 +735,7 @@ def unquote_etag(etag): return None, None etag = etag.strip() weak = False - if etag.startswith(('W/', 'w/')): + if etag.startswith(("W/", "w/")): weak = True etag = etag[2:] if etag[:1] == etag[-1:] == '"': @@ -734,7 +760,7 @@ def parse_etags(value): if match is None: break is_weak, quoted, raw = match.groups() - if raw == '*': + if raw == "*": return ETags(star_tag=True) elif quoted: raw = quoted @@ -778,8 +804,7 @@ def parse_date(value): year += 2000 elif year >= 69 and year <= 99: year += 1900 - return datetime(*((year,) + t[1:7])) - \ - timedelta(seconds=t[-1] or 0) + return datetime(*((year,) + t[1:7])) - timedelta(seconds=t[-1] or 0) except (ValueError, OverflowError): return None @@ -792,12 +817,29 @@ def _dump_date(d, delim): d = d.utctimetuple() elif isinstance(d, (integer_types, float)): d = gmtime(d) - return '%s, %02d%s%s%s%s %02d:%02d:%02d GMT' % ( - ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')[d.tm_wday], - d.tm_mday, delim, - ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', - 'Oct', 'Nov', 'Dec')[d.tm_mon - 1], - delim, str(d.tm_year), d.tm_hour, d.tm_min, d.tm_sec + return "%s, %02d%s%s%s%s %02d:%02d:%02d GMT" % ( + ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[d.tm_wday], + d.tm_mday, + delim, + ( + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + )[d.tm_mon - 1], + delim, + str(d.tm_year), + d.tm_hour, + d.tm_min, + d.tm_sec, ) @@ -813,7 +855,7 @@ def cookie_date(expires=None): :param expires: If provided that date is used, otherwise the current. """ - return _dump_date(expires, '-') + return _dump_date(expires, "-") def http_date(timestamp=None): @@ -827,7 +869,7 @@ def http_date(timestamp=None): :param timestamp: If provided that date is used, otherwise the current. """ - return _dump_date(timestamp, ' ') + return _dump_date(timestamp, " ") def parse_age(value=None): @@ -868,13 +910,14 @@ def dump_age(age=None): age = int(age) if age < 0: - raise ValueError('age cannot be negative') + raise ValueError("age cannot be negative") return str(age) -def is_resource_modified(environ, etag=None, data=None, last_modified=None, - ignore_if_range=True): +def is_resource_modified( + environ, etag=None, data=None, last_modified=None, ignore_if_range=True +): """Convenience method for conditional requests. :param environ: the WSGI environment of the request to be checked. @@ -889,8 +932,8 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, if etag is None and data is not None: etag = generate_etag(data) elif data is not None: - raise TypeError('both data and etag given') - if environ['REQUEST_METHOD'] not in ('GET', 'HEAD'): + raise TypeError("both data and etag given") + if environ["REQUEST_METHOD"] not in ("GET", "HEAD"): return False unmodified = False @@ -903,16 +946,16 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, last_modified = last_modified.replace(microsecond=0) if_range = None - if not ignore_if_range and 'HTTP_RANGE' in environ: + if not ignore_if_range and "HTTP_RANGE" in environ: # https://tools.ietf.org/html/rfc7233#section-3.2 # A server MUST ignore an If-Range header field received in a request # that does not contain a Range header field. - if_range = parse_if_range_header(environ.get('HTTP_IF_RANGE')) + if_range = parse_if_range_header(environ.get("HTTP_IF_RANGE")) if if_range is not None and if_range.date is not None: modified_since = if_range.date else: - modified_since = parse_date(environ.get('HTTP_IF_MODIFIED_SINCE')) + modified_since = parse_date(environ.get("HTTP_IF_MODIFIED_SINCE")) if modified_since and last_modified and last_modified <= modified_since: unmodified = True @@ -922,7 +965,7 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, if if_range is not None and if_range.etag is not None: unmodified = parse_etags(if_range.etag).contains(etag) else: - if_none_match = parse_etags(environ.get('HTTP_IF_NONE_MATCH')) + if_none_match = parse_etags(environ.get("HTTP_IF_NONE_MATCH")) if if_none_match: # https://tools.ietf.org/html/rfc7232#section-3.2 # "A recipient MUST use the weak comparison function when comparing @@ -932,14 +975,14 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, # https://tools.ietf.org/html/rfc7232#section-3.1 # "Origin server MUST use the strong comparison function when # comparing entity-tags for If-Match" - if_match = parse_etags(environ.get('HTTP_IF_MATCH')) + if_match = parse_etags(environ.get("HTTP_IF_MATCH")) if if_match: unmodified = not if_match.is_strong(etag) return not unmodified -def remove_entity_headers(headers, allowed=('expires', 'content-location')): +def remove_entity_headers(headers, allowed=("expires", "content-location")): """Remove all entity headers from a list or :class:`Headers` object. This operation works in-place. `Expires` and `Content-Location` headers are by default not removed. The reason for this is :rfc:`2616` section @@ -953,8 +996,11 @@ def remove_entity_headers(headers, allowed=('expires', 'content-location')): they are entity headers. """ allowed = set(x.lower() for x in allowed) - headers[:] = [(key, value) for key, value in headers if - not is_entity_header(key) or key.lower() in allowed] + headers[:] = [ + (key, value) + for key, value in headers + if not is_entity_header(key) or key.lower() in allowed + ] def remove_hop_by_hop_headers(headers): @@ -965,8 +1011,9 @@ def remove_hop_by_hop_headers(headers): :param headers: a list or :class:`Headers` object. """ - headers[:] = [(key, value) for key, value in headers if - not is_hop_by_hop_header(key)] + headers[:] = [ + (key, value) for key, value in headers if not is_hop_by_hop_header(key) + ] def is_entity_header(header): @@ -991,7 +1038,7 @@ def is_hop_by_hop_header(header): return header.lower() in _hop_by_hop_headers -def parse_cookie(header, charset='utf-8', errors='replace', cls=None): +def parse_cookie(header, charset="utf-8", errors="replace", cls=None): """Parse a cookie. Either from a string or WSGI environ. Per default encoding errors are ignored. If you want a different behavior @@ -1011,16 +1058,16 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None): used. """ if isinstance(header, dict): - header = header.get('HTTP_COOKIE', '') + header = header.get("HTTP_COOKIE", "") elif header is None: - header = '' + header = "" # If the value is an unicode string it's mangled through latin1. This # is done because on PEP 3333 on Python 3 all headers are assumed latin1 # which however is incorrect for cookies, which are sent in page encoding. # As a result we if isinstance(header, text_type): - header = header.encode('latin1', 'replace') + header = header.encode("latin1", "replace") if cls is None: cls = TypeConversionDict @@ -1036,10 +1083,20 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None): return cls(_parse_pairs()) -def dump_cookie(key, value='', max_age=None, expires=None, path='/', - domain=None, secure=False, httponly=False, - charset='utf-8', sync_expires=True, max_size=4093, - samesite=None): +def dump_cookie( + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=False, + httponly=False, + charset="utf-8", + sync_expires=True, + max_size=4093, + samesite=None, +): """Creates a new Set-Cookie header without the ``Set-Cookie`` prefix The parameters are the same as in the cookie Morsel object in the Python standard library but it accepts unicode data, too. @@ -1098,21 +1155,23 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', expires = to_bytes(cookie_date(time() + max_age)) samesite = samesite.title() if samesite else None - if samesite not in ('Strict', 'Lax', None): + if samesite not in ("Strict", "Lax", None): raise ValueError("invalid SameSite value; must be 'Strict', 'Lax' or None") - buf = [key + b'=' + _cookie_quote(value)] + buf = [key + b"=" + _cookie_quote(value)] # XXX: In theory all of these parameters that are not marked with `None` # should be quoted. Because stdlib did not quote it before I did not # want to introduce quoting there now. - for k, v, q in ((b'Domain', domain, True), - (b'Expires', expires, False,), - (b'Max-Age', max_age, False), - (b'Secure', secure, None), - (b'HttpOnly', httponly, None), - (b'Path', path, False), - (b'SameSite', samesite, False)): + for k, v, q in ( + (b"Domain", domain, True), + (b"Expires", expires, False), + (b"Max-Age", max_age, False), + (b"Secure", secure, None), + (b"HttpOnly", httponly, None), + (b"Path", path, False), + (b"SameSite", samesite, False), + ): if q is None: if v: buf.append(k) @@ -1126,15 +1185,15 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', v = to_bytes(text_type(v), charset) if q: v = _cookie_quote(v) - tmp += b'=' + v + tmp += b"=" + v buf.append(bytes(tmp)) # The return value will be an incorrectly encoded latin1 header on # Python 3 for consistency with the headers object and a bytestring # on Python 2 because that's how the API makes more sense. - rv = b'; '.join(buf) + rv = b"; ".join(buf) if not PY2: - rv = rv.decode('latin1') + rv = rv.decode("latin1") # Warn if the final value of the cookie is less than the limit. If the # cookie is too large, then it may be silently ignored, which can be quite @@ -1145,16 +1204,16 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', value_size = len(value) warnings.warn( 'The "{key}" cookie is too large: the value was {value_size} bytes' - ' but the header required {extra_size} extra bytes. The final size' - ' was {cookie_size} bytes but the limit is {max_size} bytes.' - ' Browsers may silently ignore cookies larger than this.'.format( + " but the header required {extra_size} extra bytes. The final size" + " was {cookie_size} bytes but the limit is {max_size} bytes." + " Browsers may silently ignore cookies larger than this.".format( key=key, value_size=value_size, extra_size=cookie_size - value_size, cookie_size=cookie_size, - max_size=max_size + max_size=max_size, ), - stacklevel=2 + stacklevel=2, ) return rv @@ -1177,19 +1236,23 @@ def is_byte_range_valid(start, stop, length): # circular dependency fun -from werkzeug.datastructures import Accept, HeaderSet, ETags, Authorization, \ - WWWAuthenticate, TypeConversionDict, IfRange, Range, ContentRange, \ - RequestCacheControl -from werkzeug.urls import iri_to_uri - +from .datastructures import Accept +from .datastructures import Authorization +from .datastructures import ContentRange +from .datastructures import ETags +from .datastructures import HeaderSet +from .datastructures import IfRange +from .datastructures import Range +from .datastructures import RequestCacheControl +from .datastructures import TypeConversionDict +from .datastructures import WWWAuthenticate +from .urls import iri_to_uri # DEPRECATED -from werkzeug.datastructures import ( - MIMEAccept as _MIMEAccept, - CharsetAccept as _CharsetAccept, - LanguageAccept as _LanguageAccept, - Headers as _Headers, -) +from .datastructures import CharsetAccept as _CharsetAccept +from .datastructures import Headers as _Headers +from .datastructures import LanguageAccept as _LanguageAccept +from .datastructures import MIMEAccept as _MIMEAccept class MIMEAccept(_MIMEAccept): diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index af98c305..9a6088cc 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -10,8 +10,10 @@ """ import copy from functools import update_wrapper -from werkzeug.wsgi import ClosingIterator -from werkzeug._compat import PY2, implements_bool + +from ._compat import implements_bool +from ._compat import PY2 +from .wsgi import ClosingIterator # since each thread has its own greenlet we can just use those as identifiers # for the context. If greenlets are not available we fall back to the @@ -49,11 +51,11 @@ def release_local(local): class Local(object): - __slots__ = ('__storage__', '__ident_func__') + __slots__ = ("__storage__", "__ident_func__") def __init__(self): - object.__setattr__(self, '__storage__', {}) - object.__setattr__(self, '__ident_func__', get_ident) + object.__setattr__(self, "__storage__", {}) + object.__setattr__(self, "__ident_func__", get_ident) def __iter__(self): return iter(self.__storage__.items()) @@ -87,7 +89,6 @@ class Local(object): class LocalStack(object): - """This class works similar to a :class:`Local` but keeps a stack of objects instead. This is best explained with an example:: @@ -124,7 +125,8 @@ class LocalStack(object): return self._local.__ident_func__ def _set__ident_func__(self, value): - object.__setattr__(self._local, '__ident_func__', value) + object.__setattr__(self._local, "__ident_func__", value) + __ident_func__ = property(_get__ident_func__, _set__ident_func__) del _get__ident_func__, _set__ident_func__ @@ -132,13 +134,14 @@ class LocalStack(object): def _lookup(): rv = self.top if rv is None: - raise RuntimeError('object unbound') + raise RuntimeError("object unbound") return rv + return LocalProxy(_lookup) def push(self, obj): """Pushes a new item to the stack""" - rv = getattr(self._local, 'stack', None) + rv = getattr(self._local, "stack", None) if rv is None: self._local.stack = rv = [] rv.append(obj) @@ -148,7 +151,7 @@ class LocalStack(object): """Removes the topmost item from the stack, will return the old value or `None` if the stack was already empty. """ - stack = getattr(self._local, 'stack', None) + stack = getattr(self._local, "stack", None) if stack is None: return None elif len(stack) == 1: @@ -169,7 +172,6 @@ class LocalStack(object): class LocalManager(object): - """Local objects cannot manage themselves. For that you need a local manager. You can pass a local manager multiple locals or add them later by appending them to `manager.locals`. Every time the manager cleans up, @@ -196,7 +198,7 @@ class LocalManager(object): if ident_func is not None: self.ident_func = ident_func for local in self.locals: - object.__setattr__(local, '__ident_func__', ident_func) + object.__setattr__(local, "__ident_func__", ident_func) else: self.ident_func = get_ident @@ -224,8 +226,10 @@ class LocalManager(object): """Wrap a WSGI application so that cleaning up happens after request end. """ + def application(environ, start_response): return ClosingIterator(app(environ, start_response), self.cleanup) + return application def middleware(self, func): @@ -244,15 +248,11 @@ class LocalManager(object): return update_wrapper(self.make_middleware(func), func) def __repr__(self): - return '<%s storages: %d>' % ( - self.__class__.__name__, - len(self.locals) - ) + return "<%s storages: %d>" % (self.__class__.__name__, len(self.locals)) @implements_bool class LocalProxy(object): - """Acts as a proxy for a werkzeug local. Forwards all operations to a proxied object. The only operations not supported for forwarding are right handed operands and any kind of assignment. @@ -287,40 +287,41 @@ class LocalProxy(object): .. versionchanged:: 0.6.1 The class can be instantiated with a callable as well now. """ - __slots__ = ('__local', '__dict__', '__name__', '__wrapped__') + + __slots__ = ("__local", "__dict__", "__name__", "__wrapped__") def __init__(self, local, name=None): - object.__setattr__(self, '_LocalProxy__local', local) - object.__setattr__(self, '__name__', name) - if callable(local) and not hasattr(local, '__release_local__'): + object.__setattr__(self, "_LocalProxy__local", local) + object.__setattr__(self, "__name__", name) + if callable(local) and not hasattr(local, "__release_local__"): # "local" is a callable that is not an instance of Local or # LocalManager: mark it as a wrapped function. - object.__setattr__(self, '__wrapped__', local) + object.__setattr__(self, "__wrapped__", local) def _get_current_object(self): """Return the current object. This is useful if you want the real object behind the proxy at a time for performance reasons or because you want to pass the object into a different context. """ - if not hasattr(self.__local, '__release_local__'): + if not hasattr(self.__local, "__release_local__"): return self.__local() try: return getattr(self.__local, self.__name__) except AttributeError: - raise RuntimeError('no object bound to %s' % self.__name__) + raise RuntimeError("no object bound to %s" % self.__name__) @property def __dict__(self): try: return self._get_current_object().__dict__ except RuntimeError: - raise AttributeError('__dict__') + raise AttributeError("__dict__") def __repr__(self): try: obj = self._get_current_object() except RuntimeError: - return '<%s unbound>' % self.__class__.__name__ + return "<%s unbound>" % self.__class__.__name__ return repr(obj) def __bool__(self): @@ -342,7 +343,7 @@ class LocalProxy(object): return [] def __getattr__(self, name): - if name == '__members__': + if name == "__members__": return dir(self._get_current_object()) return getattr(self._get_current_object(), name) diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 0890cf00..bfdc0712 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,7 +7,6 @@ Basic HTTP Proxy :copyright: 2007 Pallets :license: BSD-3-Clause """ - import socket from ..datastructures import EnvironHeaders @@ -117,7 +116,7 @@ class ProxyMiddleware(object): if opts["remove_prefix"]: remote_path = "%s/%s" % ( target.path.rstrip("/"), - remote_path[len(prefix):].lstrip("/"), + remote_path[len(prefix) :].lstrip("/"), ) content_length = environ.get("CONTENT_LENGTH") @@ -179,7 +178,7 @@ class ProxyMiddleware(object): resp = con.getresponse() except socket.error: - from werkzeug.exceptions import BadGateway + from ..exceptions import BadGateway return BadGateway()(environ, start_response) diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 0c71f9ca..15372819 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -164,7 +164,7 @@ class GuardedIterator(object): content_length = headers.get("content-length", type=int) if status_code == 304: - for key, value in headers: + for key, _value in headers: key = key.lower() if key not in ("expires", "content-location") and is_entity_header( key diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 64041f61..8e2edc20 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -12,6 +12,7 @@ that may be slowing down your application. :license: BSD-3-Clause """ from __future__ import print_function + import os.path import sys import time diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index 6fcd2ef6..dc1dacc8 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -141,7 +141,7 @@ class ProxyFix(object): "'get_remote_addr' is deprecated as of version 0.15 and" " will be removed in version 1.0. It is now handled" " internally for each header.", - DeprecationWarning + DeprecationWarning, ) return self._get_trusted_comma(self.x_for, ",".join(forwarded_for)) diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 5ea3f87e..a902281d 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -8,7 +8,6 @@ Serve Shared Static Files :copyright: 2007 Pallets :license: BSD-3-Clause """ - import mimetypes import os import posixpath @@ -219,7 +218,7 @@ class SharedDataMiddleware(object): search_path += "/" if path.startswith(search_path): - real_filename, file_loader = loader(path[len(search_path):]) + real_filename, file_loader = loader(path[len(search_path) :]) if file_loader is not None: break diff --git a/src/werkzeug/posixemulation.py b/src/werkzeug/posixemulation.py index dbf909de..696b4562 100644 --- a/src/werkzeug/posixemulation.py +++ b/src/werkzeug/posixemulation.py @@ -17,21 +17,18 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys -import os import errno -import time +import os import random +import sys +import time from ._compat import to_unicode from .filesystem import get_filesystem_encoding - can_rename_open_file = False -if os.name == 'nt': # pragma: no cover - _rename = lambda src, dst: False - _rename_atomic = lambda src, dst: False +if os.name == "nt": try: import ctypes @@ -47,8 +44,9 @@ if os.name == 'nt': # pragma: no cover retry = 0 rv = False while not rv and retry < 100: - rv = _MoveFileEx(src, dst, _MOVEFILE_REPLACE_EXISTING - | _MOVEFILE_WRITE_THROUGH) + rv = _MoveFileEx( + src, dst, _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH + ) if not rv: time.sleep(0.001) retry += 1 @@ -62,16 +60,21 @@ if os.name == 'nt': # pragma: no cover can_rename_open_file = True def _rename_atomic(src, dst): - ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, 'Werkzeug rename') + ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, "Werkzeug rename") if ta == -1: return False try: retry = 0 rv = False while not rv and retry < 100: - rv = _MoveFileTransacted(src, dst, None, None, - _MOVEFILE_REPLACE_EXISTING - | _MOVEFILE_WRITE_THROUGH, ta) + rv = _MoveFileTransacted( + src, + dst, + None, + None, + _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH, + ta, + ) if rv: rv = _CommitTransaction(ta) break @@ -81,8 +84,14 @@ if os.name == 'nt': # pragma: no cover return rv finally: _CloseHandle(ta) + except Exception: - pass + + def _rename(src, dst): + return False + + def _rename_atomic(src, dst): + return False def rename(src, dst): # Try atomic or pseudo-atomic rename @@ -101,6 +110,8 @@ if os.name == 'nt': # pragma: no cover os.unlink(old) except Exception: pass + + else: rename = os.rename can_rename_open_file = True diff --git a/src/werkzeug/routing.py b/src/werkzeug/routing.py index 3e349814..875c0154 100644 --- a/src/werkzeug/routing.py +++ b/src/werkzeug/routing.py @@ -96,30 +96,44 @@ :license: BSD-3-Clause """ import difflib -import re -import uuid -import posixpath import dis +import posixpath +import re import sys import types - +import uuid from functools import partial from pprint import pformat from threading import Lock -from werkzeug.urls import url_encode, url_quote, url_join, _fast_url_quote -from werkzeug.utils import redirect, format_string -from werkzeug.exceptions import HTTPException, NotFound, MethodNotAllowed, \ - BadHost -from werkzeug._internal import _get_environ, _encode_idna -from werkzeug._compat import itervalues, iteritems, to_unicode, to_bytes, \ - text_type, string_types, native_string_result, \ - implements_to_string, wsgi_decoding_dance -from werkzeug.datastructures import ImmutableDict, MultiDict -from werkzeug.utils import cached_property -from werkzeug.wsgi import get_host - -_rule_re = re.compile(r''' +from ._compat import implements_to_string +from ._compat import iteritems +from ._compat import itervalues +from ._compat import native_string_result +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import wsgi_decoding_dance +from ._internal import _encode_idna +from ._internal import _get_environ +from .datastructures import ImmutableDict +from .datastructures import MultiDict +from .exceptions import BadHost +from .exceptions import HTTPException +from .exceptions import MethodNotAllowed +from .exceptions import NotFound +from .urls import _fast_url_quote +from .urls import url_encode +from .urls import url_join +from .urls import url_quote +from .utils import cached_property +from .utils import format_string +from .utils import redirect +from .wsgi import get_host + +_rule_re = re.compile( + r""" (?P<static>[^<]*) # static rule data < (?: @@ -129,9 +143,12 @@ _rule_re = re.compile(r''' )? (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*) # variable name > -''', re.VERBOSE) -_simple_rule_re = re.compile(r'<([^>]+)>') -_converter_args_re = re.compile(r''' + """, + re.VERBOSE, +) +_simple_rule_re = re.compile(r"<([^>]+)>") +_converter_args_re = re.compile( + r""" ((?P<name>\w+)\s*=\s*)? (?P<value> True|False| @@ -141,14 +158,12 @@ _converter_args_re = re.compile(r''' [\w\d_.]+| [urUR]?(?P<stringval>"[^"]*?"|'[^']*') )\s*, -''', re.VERBOSE | re.UNICODE) + """, + re.VERBOSE | re.UNICODE, +) -_PYTHON_CONSTANTS = { - 'None': None, - 'True': True, - 'False': False -} +_PYTHON_CONSTANTS = {"None": None, "True": True, "False": False} def _pythonize(value): @@ -159,25 +174,25 @@ def _pythonize(value): return convert(value) except ValueError: pass - if value[:1] == value[-1:] and value[0] in '"\'': + if value[:1] == value[-1:] and value[0] in "\"'": value = value[1:-1] return text_type(value) def parse_converter_args(argstr): - argstr += ',' + argstr += "," args = [] kwargs = {} for item in _converter_args_re.finditer(argstr): - value = item.group('stringval') + value = item.group("stringval") if value is None: - value = item.group('value') + value = item.group("value") value = _pythonize(value) - if not item.group('name'): + if not item.group("name"): args.append(value) else: - name = item.group('name') + name = item.group("name") kwargs[name] = value return tuple(args), kwargs @@ -199,24 +214,23 @@ def parse_rule(rule): if m is None: break data = m.groupdict() - if data['static']: - yield None, None, data['static'] - variable = data['variable'] - converter = data['converter'] or 'default' + if data["static"]: + yield None, None, data["static"] + variable = data["variable"] + converter = data["converter"] or "default" if variable in used_names: - raise ValueError('variable name %r used twice.' % variable) + raise ValueError("variable name %r used twice." % variable) used_names.add(variable) - yield converter, data['args'] or None, variable + yield converter, data["args"] or None, variable pos = m.end() if pos < end: remaining = rule[pos:] - if '>' in remaining or '<' in remaining: - raise ValueError('malformed url rule: %r' % rule) + if ">" in remaining or "<" in remaining: + raise ValueError("malformed url rule: %r" % rule) yield None, None, remaining class RoutingException(Exception): - """Special exceptions that require the application to redirect, notifying about missing urls, etc. @@ -225,12 +239,12 @@ class RoutingException(Exception): class RequestRedirect(HTTPException, RoutingException): - """Raise if the map requests a redirect. This is for example the case if `strict_slashes` are activated and an url that requires a trailing slash. The attribute `new_url` contains the absolute destination url. """ + code = 308 def __init__(self, new_url): @@ -242,12 +256,10 @@ class RequestRedirect(HTTPException, RoutingException): class RequestSlash(RoutingException): - """Internal exception.""" -class RequestAliasRedirect(RoutingException): - +class RequestAliasRedirect(RoutingException): # noqa: B903 """This rule is an alias and wants to redirect to the canonical URL.""" def __init__(self, matched_values): @@ -256,7 +268,6 @@ class RequestAliasRedirect(RoutingException): @implements_to_string class BuildError(RoutingException, LookupError): - """Raised if the build system cannot find a URL for an endpoint with the values provided. """ @@ -274,55 +285,54 @@ class BuildError(RoutingException, LookupError): def closest_rule(self, adapter): def _score_rule(rule): - return sum([ - 0.98 * difflib.SequenceMatcher( - None, rule.endpoint, self.endpoint - ).ratio(), - 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), - 0.01 * bool(rule.methods and self.method in rule.methods) - ]) + return sum( + [ + 0.98 + * difflib.SequenceMatcher( + None, rule.endpoint, self.endpoint + ).ratio(), + 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), + 0.01 * bool(rule.methods and self.method in rule.methods), + ] + ) if adapter and adapter.map._rules: return max(adapter.map._rules, key=_score_rule) def __str__(self): message = [] - message.append('Could not build url for endpoint %r' % self.endpoint) + message.append("Could not build url for endpoint %r" % self.endpoint) if self.method: - message.append(' (%r)' % self.method) + message.append(" (%r)" % self.method) if self.values: - message.append(' with values %r' % sorted(self.values.keys())) - message.append('.') + message.append(" with values %r" % sorted(self.values.keys())) + message.append(".") if self.suggested: if self.endpoint == self.suggested.endpoint: if self.method and self.method not in self.suggested.methods: - message.append(' Did you mean to use methods %r?' % sorted( - self.suggested.methods - )) + message.append( + " Did you mean to use methods %r?" + % sorted(self.suggested.methods) + ) missing_values = self.suggested.arguments.union( set(self.suggested.defaults or ()) ) - set(self.values.keys()) if missing_values: message.append( - ' Did you forget to specify values %r?' % - sorted(missing_values) + " Did you forget to specify values %r?" % sorted(missing_values) ) else: - message.append( - ' Did you mean %r instead?' % self.suggested.endpoint - ) - return u''.join(message) + message.append(" Did you mean %r instead?" % self.suggested.endpoint) + return u"".join(message) class ValidationError(ValueError): - """Validation error. If a rule converter raises this exception the rule does not match the current URL and the next URL is tried. """ class RuleFactory(object): - """As soon as you have more complex URL setups it's a good idea to use rule factories to avoid repetitive tasks. Some of them are builtin, others can be added by subclassing `RuleFactory` and overriding `get_rules`. @@ -335,7 +345,6 @@ class RuleFactory(object): class Subdomain(RuleFactory): - """All URLs provided by this factory have the subdomain set to a specific domain. For example if you want to use the subdomain for the current language this can be a good setup:: @@ -367,7 +376,6 @@ class Subdomain(RuleFactory): class Submount(RuleFactory): - """Like `Subdomain` but prefixes the URL rule with a given string:: url_map = Map([ @@ -382,7 +390,7 @@ class Submount(RuleFactory): """ def __init__(self, path, rules): - self.path = path.rstrip('/') + self.path = path.rstrip("/") self.rules = rules def get_rules(self, map): @@ -394,7 +402,6 @@ class Submount(RuleFactory): class EndpointPrefix(RuleFactory): - """Prefixes all endpoints (which must be strings for this factory) with another string. This can be useful for sub applications:: @@ -420,7 +427,6 @@ class EndpointPrefix(RuleFactory): class RuleTemplate(object): - """Returns copies of the rules wrapped and expands string templates in the endpoint, rule, defaults or subdomain sections. @@ -447,7 +453,6 @@ class RuleTemplate(object): class RuleTemplateFactory(RuleFactory): - """A factory that fills in template variables into rules. Used by `RuleTemplate` internally. @@ -480,13 +485,12 @@ class RuleTemplateFactory(RuleFactory): rule.methods, rule.build_only, new_endpoint, - rule.strict_slashes + rule.strict_slashes, ) @implements_to_string class Rule(RuleFactory): - """A Rule represents one URL pattern. There are some options for `Rule` that change the way it behaves and are passed to the `Rule` constructor. Note that besides the rule-string all arguments *must* be keyword arguments @@ -600,13 +604,23 @@ class Rule(RuleFactory): The `alias` and `host` parameters were added. """ - def __init__(self, string, defaults=None, subdomain=None, methods=None, - build_only=False, endpoint=None, strict_slashes=None, - redirect_to=None, alias=False, host=None): - if not string.startswith('/'): - raise ValueError('urls must start with a leading slash') + def __init__( + self, + string, + defaults=None, + subdomain=None, + methods=None, + build_only=False, + endpoint=None, + strict_slashes=None, + redirect_to=None, + alias=False, + host=None, + ): + if not string.startswith("/"): + raise ValueError("urls must start with a leading slash") self.rule = string - self.is_leaf = not string.endswith('/') + self.is_leaf = not string.endswith("/") self.map = None self.strict_slashes = strict_slashes @@ -619,10 +633,10 @@ class Rule(RuleFactory): self.methods = None else: if isinstance(methods, str): - raise TypeError('param `methods` should be `Iterable[str]`, not `str`') + raise TypeError("param `methods` should be `Iterable[str]`, not `str`") self.methods = set([x.upper() for x in methods]) - if 'HEAD' not in self.methods and 'GET' in self.methods: - self.methods.add('HEAD') + if "HEAD" not in self.methods and "GET" in self.methods: + self.methods.add("HEAD") self.endpoint = endpoint self.redirect_to = redirect_to @@ -657,11 +671,17 @@ class Rule(RuleFactory): defaults = None if self.defaults: defaults = dict(self.defaults) - return dict(defaults=defaults, subdomain=self.subdomain, - methods=self.methods, build_only=self.build_only, - endpoint=self.endpoint, strict_slashes=self.strict_slashes, - redirect_to=self.redirect_to, alias=self.alias, - host=self.host) + return dict( + defaults=defaults, + subdomain=self.subdomain, + methods=self.methods, + build_only=self.build_only, + endpoint=self.endpoint, + strict_slashes=self.strict_slashes, + redirect_to=self.redirect_to, + alias=self.alias, + host=self.host, + ) def get_rules(self, map): yield self @@ -681,8 +701,7 @@ class Rule(RuleFactory): :internal: """ if self.map is not None and not rebind: - raise RuntimeError('url rule %r already bound to map %r' % - (self, self.map)) + raise RuntimeError("url rule %r already bound to map %r" % (self, self.map)) self.map = map if self.strict_slashes is None: self.strict_slashes = map.strict_slashes @@ -696,17 +715,17 @@ class Rule(RuleFactory): .. versionadded:: 0.9 """ if converter_name not in self.map.converters: - raise LookupError('the converter %r does not exist' % converter_name) + raise LookupError("the converter %r does not exist" % converter_name) return self.map.converters[converter_name](self.map, *args, **kwargs) def compile(self): """Compiles the regular expression and stores it.""" - assert self.map is not None, 'rule not bound' + assert self.map is not None, "rule not bound" if self.map.host_matching: - domain_rule = self.host or '' + domain_rule = self.host or "" else: - domain_rule = self.subdomain or '' + domain_rule = self.subdomain or "" self._trace = [] self._converters = {} @@ -720,7 +739,7 @@ class Rule(RuleFactory): if converter is None: regex_parts.append(re.escape(variable)) self._trace.append((False, variable)) - for part in variable.split('/'): + for part in variable.split("/"): if part: self._static_weights.append((index, -len(part))) else: @@ -729,9 +748,8 @@ class Rule(RuleFactory): else: c_args = () c_kwargs = {} - convobj = self.get_converter( - variable, converter, c_args, c_kwargs) - regex_parts.append('(?P<%s>%s)' % (variable, convobj.regex)) + convobj = self.get_converter(variable, converter, c_args, c_kwargs) + regex_parts.append("(?P<%s>%s)" % (variable, convobj.regex)) self._converters[variable] = convobj self._trace.append((True, variable)) self._argument_weights.append(convobj.weight) @@ -739,21 +757,22 @@ class Rule(RuleFactory): index = index + 1 _build_regex(domain_rule) - regex_parts.append('\\|') - self._trace.append((False, '|')) - _build_regex(self.rule if self.is_leaf else self.rule.rstrip('/')) + regex_parts.append("\\|") + self._trace.append((False, "|")) + _build_regex(self.rule if self.is_leaf else self.rule.rstrip("/")) if not self.is_leaf: - self._trace.append((False, '/')) + self._trace.append((False, "/")) self._build = self._compile_builder(False) self._build_unknown = self._compile_builder(True) if self.build_only: return - regex = r'^%s%s$' % ( - u''.join(regex_parts), + regex = r"^%s%s$" % ( + u"".join(regex_parts), (not self.is_leaf or not self.strict_slashes) - and '(?<!/)(?P<__suffix__>/?)' or '' + and "(?<!/)(?P<__suffix__>/?)" + or "", ) self._regex = re.compile(regex, re.UNICODE) @@ -776,15 +795,19 @@ class Rule(RuleFactory): # slash and strict slashes enabled. raise an exception that # tells the map to redirect to the same url but with a # trailing slash - if self.strict_slashes and not self.is_leaf and \ - not groups.pop('__suffix__') and \ - (method is None or self.methods is None - or method in self.methods): + if ( + self.strict_slashes + and not self.is_leaf + and not groups.pop("__suffix__") + and ( + method is None or self.methods is None or method in self.methods + ) + ): raise RequestSlash() # if we are not in strict slashes mode we have to remove # a __suffix__ elif not self.strict_slashes: - del groups['__suffix__'] + del groups["__suffix__"] result = {} for name, value in iteritems(groups): @@ -802,7 +825,7 @@ class Rule(RuleFactory): return result class BuilderCompiler: - JOIN_EMPTY = ''.join + JOIN_EMPTY = "".join if sys.version_info >= (3, 6): OPARG_SIZE = 256 OPARG_VARI = False @@ -876,14 +899,14 @@ class Rule(RuleFactory): if op is not None: new.append((op, elem)) continue - if elem == '': + if elem == "": continue if not new or new[-1][0] is not None: new.append((op, elem)) continue new[-1] = (None, new[-1][1] + elem) if not new: - new.append((None, '')) + new.append((None, "")) return new def build_op(self, op, arg=None): @@ -891,20 +914,17 @@ class Rule(RuleFactory): if isinstance(op, str): op = dis.opmap[op] if arg is None and op >= dis.HAVE_ARGUMENT: - raise ValueError( - "Operation requires an argument: %s" % dis.opname[op]) + raise ValueError("Operation requires an argument: %s" % dis.opname[op]) if arg is not None and op < dis.HAVE_ARGUMENT: - raise ValueError( - "Operation takes no argument: %s" % dis.opname[op]) + raise ValueError("Operation takes no argument: %s" % dis.opname[op]) if arg is None: arg = 0 # Python 3.6 changed the argument to an 8-bit integer, so this # could be a practical consideration if arg >= self.OPARG_SIZE: - return ( - self.build_op('EXTENDED_ARG', arg // self.OPARG_SIZE) - + self.build_op(op, arg % self.OPARG_SIZE) - ) + return self.build_op( + "EXTENDED_ARG", arg // self.OPARG_SIZE + ) + self.build_op(op, arg % self.OPARG_SIZE) if not self.OPARG_VARI: return bytearray((op, arg)) elif op >= dis.HAVE_ARGUMENT: @@ -918,58 +938,57 @@ class Rule(RuleFactory): already be immediately below the string elements on the stack. """ - if 'BUILD_STRING' in dis.opmap: - return self.build_op('BUILD_STRING', n) + if "BUILD_STRING" in dis.opmap: + return self.build_op("BUILD_STRING", n) else: - return ( - self.build_op('BUILD_TUPLE', n) - + self.build_op('CALL_FUNCTION', 1) + return self.build_op("BUILD_TUPLE", n) + self.build_op( + "CALL_FUNCTION", 1 ) def emit_build( - self, ind, opl, append_unknown=False, encode_query_vars=None, - kwargs=None + self, ind, opl, append_unknown=False, encode_query_vars=None, kwargs=None ): - ops = b'' + ops = b"" n = len(opl) stack = 0 stack_overhead = 0 for op, elem in opl: if op is None: - ops += self.build_op('LOAD_CONST', self.get_const(elem)) + ops += self.build_op("LOAD_CONST", self.get_const(elem)) stack_overhead = 0 continue - ops += self.build_op('LOAD_CONST', self.get_const(op)) - ops += self.build_op('LOAD_FAST', self.get_var(elem)) - ops += self.build_op('CALL_FUNCTION', 1) + ops += self.build_op("LOAD_CONST", self.get_const(op)) + ops += self.build_op("LOAD_FAST", self.get_var(elem)) + ops += self.build_op("CALL_FUNCTION", 1) stack_overhead = 2 stack += len(opl) peak_stack = stack + stack_overhead dont_build_string = False - needs_build_string = 'BUILD_STRING' not in dis.opmap + needs_build_string = "BUILD_STRING" not in dis.opmap if n <= 1: dont_build_string = True needs_build_string = False if append_unknown: - if 'BUILD_STRING' not in dis.opmap: + if "BUILD_STRING" not in dis.opmap: needs_build_string = True - ops = self.build_op( - 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops - ops += self.build_op('LOAD_FAST', kwargs) + ops = ( + self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY)) + + ops + ) + ops += self.build_op("LOAD_FAST", kwargs) # assemble this in its own buffers because we need to # jump over it uops = bytearray() # run if kwargs. TOS=kwargs - uops += self.build_op( - 'LOAD_CONST', self.get_const(encode_query_vars)) - uops += self.build_op('ROT_TWO') - uops += self.build_op('CALL_FUNCTION', 1) - uops += self.build_op('LOAD_CONST', self.get_const('?')) - uops += self.build_op('ROT_TWO') + uops += self.build_op("LOAD_CONST", self.get_const(encode_query_vars)) + uops += self.build_op("ROT_TWO") + uops += self.build_op("CALL_FUNCTION", 1) + uops += self.build_op("LOAD_CONST", self.get_const("?")) + uops += self.build_op("ROT_TWO") if dont_build_string: uops += self.build_string(n + 2) @@ -977,23 +996,23 @@ class Rule(RuleFactory): if not dont_build_string: # if we're going to build a string, we need to pad out to # a constant length - nops += self.build_op('LOAD_CONST', self.get_const('')) - nops += self.build_op('DUP_TOP') + nops += self.build_op("LOAD_CONST", self.get_const("")) + nops += self.build_op("DUP_TOP") elif needs_build_string: # we inserted the ''.join reference at the bottom of the # stack, but we don't want to call it: throw it away - nops += self.build_op('ROT_TWO') - nops += self.build_op('POP_TOP') - nops += self.build_op('JUMP_FORWARD', len(uops)) + nops += self.build_op("ROT_TWO") + nops += self.build_op("POP_TOP") + nops += self.build_op("JUMP_FORWARD", len(uops)) # this jump needs to take its own length into account. the # simple way to do that is to compute a minimal guess for the # length of the jump instruction, and keep revising it upward - jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', 0) + jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", 0) while True: jump_len = len(jump_op) jump_target = ind + len(ops) + jump_len + len(nops) - jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', jump_target) + jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", jump_target) assert len(jump_op) >= jump_len if len(jump_op) == jump_len: break @@ -1005,8 +1024,7 @@ class Rule(RuleFactory): n += 2 peak_stack = max(peak_stack, stack + 2) elif needs_build_string: - ops = self.build_op( - 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops + ops = self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY)) + ops peak_stack += 1 if not dont_build_string: ops += self.build_string(n) @@ -1022,35 +1040,41 @@ class Rule(RuleFactory): url_encode, charset=self.rule.map.charset, sort=self.rule.map.sort_parameters, - key=self.rule.map.sort_key) + key=self.rule.map.sort_key, + ) for is_dynamic, data in self.rule._trace: - if data == '|' and opl is dom_ops: + if data == "|" and opl is dom_ops: opl = url_ops continue # this seems like a silly case to ever come up but: # if a default is given for a value that appears in the rule, # resolve it to a constant ahead of time if is_dynamic and data in self.defaults: - data = self.rule._converters[data].to_url( - self.defaults[data]) + data = self.rule._converters[data].to_url(self.defaults[data]) is_dynamic = False if not is_dynamic: - opl.append((None, url_quote( - to_bytes(data, self.rule.map.charset), safe='/:|+'))) + opl.append( + ( + None, + url_quote( + to_bytes(data, self.rule.map.charset), safe="/:|+" + ), + ) + ) continue opl.append((self.rule._converters[data].to_url, data)) dom_ops = self.collapse_constants(dom_ops) url_ops = self.collapse_constants(url_ops) - for op, elem in (dom_ops + url_ops): + for op, elem in dom_ops + url_ops: if op is not None: self.get_var(elem) self.add_defaults() argcount = len(self.var) # invalid name for paranoia reasons - self.get_var('.keyword_arguments') + self.get_var(".keyword_arguments") stack = 0 peak_stack = 0 - ops = b'' + ops = b"" if ( not append_unknown and len(dom_ops) == len(url_ops) == 1 @@ -1059,8 +1083,7 @@ class Rule(RuleFactory): # shortcut: just return the constant stack = peak_stack = 1 constant_value = (dom_ops[0][1], url_ops[0][1]) - ops += self.build_op( - 'LOAD_CONST', self.get_const(constant_value)) + ops += self.build_op("LOAD_CONST", self.get_const(constant_value)) else: ps, rv = self.emit_build(len(ops), dom_ops) ops += rv @@ -1068,14 +1091,14 @@ class Rule(RuleFactory): stack += 1 if append_unknown: ps, rv = self.emit_build( - len(ops), url_ops, append_unknown, encode_query_vars, - argcount) + len(ops), url_ops, append_unknown, encode_query_vars, argcount + ) else: ps, rv = self.emit_build(len(ops), url_ops) ops += rv peak_stack = max(stack + ps, peak_stack) - ops += self.build_op('BUILD_TUPLE', 2) - ops += self.build_op('RETURN_VALUE') + ops += self.build_op("BUILD_TUPLE", 2) + ops += self.build_op("RETURN_VALUE") code_args = [ argcount, len(self.var), @@ -1085,10 +1108,10 @@ class Rule(RuleFactory): tuple(self.consts), (), tuple(self.var), - 'generated', - '<builder:%r>' % self.rule.rule, + "generated", + "<builder:%r>" % self.rule.rule, 1, - b'' + b"", ] if sys.version_info >= (3,): code_args[1:1] = [0] @@ -1124,9 +1147,13 @@ class Rule(RuleFactory): :internal: """ - return not self.build_only and self.defaults and \ - self.endpoint == rule.endpoint and self != rule and \ - self.arguments == rule.arguments + return ( + not self.build_only + and self.defaults + and self.endpoint == rule.endpoint + and self != rule + and self.arguments == rule.arguments + ) def suitable_for(self, values, method=None): """Check if the dict of values has enough data for url generation. @@ -1135,8 +1162,11 @@ class Rule(RuleFactory): """ # if a method was given explicitly and that method is not supported # by this rule, this rule is not suitable. - if method is not None and self.methods is not None \ - and method not in self.methods: + if ( + method is not None + and self.methods is not None + and method not in self.methods + ): return False defaults = self.defaults or () @@ -1174,20 +1204,23 @@ class Rule(RuleFactory): :internal: """ - return bool(self.arguments), -len(self._static_weights), self._static_weights,\ - -len(self._argument_weights), self._argument_weights + return ( + bool(self.arguments), + -len(self._static_weights), + self._static_weights, + -len(self._argument_weights), + self._argument_weights, + ) def build_compare_key(self): """The build compare key for sorting. :internal: """ - return 1 if self.alias else 0, -len(self.arguments), \ - -len(self.defaults or ()) + return 1 if self.alias else 0, -len(self.arguments), -len(self.defaults or ()) def __eq__(self, other): - return self.__class__ is other.__class__ and \ - self._trace == other._trace + return self.__class__ is other.__class__ and self._trace == other._trace __hash__ = None @@ -1200,27 +1233,25 @@ class Rule(RuleFactory): @native_string_result def __repr__(self): if self.map is None: - return u'<%s (unbound)>' % self.__class__.__name__ + return u"<%s (unbound)>" % self.__class__.__name__ tmp = [] for is_dynamic, data in self._trace: if is_dynamic: - tmp.append(u'<%s>' % data) + tmp.append(u"<%s>" % data) else: tmp.append(data) - return u'<%s %s%s -> %s>' % ( + return u"<%s %s%s -> %s>" % ( self.__class__.__name__, - repr((u''.join(tmp)).lstrip(u'|')).lstrip(u'u'), - self.methods is not None - and u' (%s)' % u', '.join(self.methods) - or u'', - self.endpoint + repr((u"".join(tmp)).lstrip(u"|")).lstrip(u"u"), + self.methods is not None and u" (%s)" % u", ".join(self.methods) or u"", + self.endpoint, ) class BaseConverter(object): - """Base class for all converters.""" - regex = '[^/]+' + + regex = "[^/]+" weight = 100 def __init__(self, map): @@ -1234,7 +1265,6 @@ class BaseConverter(object): class UnicodeConverter(BaseConverter): - """This converter is the default converter and accepts any string but only one path segment. Thus the string can not include a slash. @@ -1255,21 +1285,17 @@ class UnicodeConverter(BaseConverter): def __init__(self, map, minlength=1, maxlength=None, length=None): BaseConverter.__init__(self, map) if length is not None: - length = '{%d}' % int(length) + length = "{%d}" % int(length) else: if maxlength is None: - maxlength = '' + maxlength = "" else: maxlength = int(maxlength) - length = '{%s,%s}' % ( - int(minlength), - maxlength - ) - self.regex = '[^/]' + length + length = "{%s,%s}" % (int(minlength), maxlength) + self.regex = "[^/]" + length class AnyConverter(BaseConverter): - """Matches one of the items provided. Items can either be Python identifiers or strings:: @@ -1282,11 +1308,10 @@ class AnyConverter(BaseConverter): def __init__(self, map, *items): BaseConverter.__init__(self, map) - self.regex = '(?:%s)' % '|'.join([re.escape(x) for x in items]) + self.regex = "(?:%s)" % "|".join([re.escape(x) for x in items]) class PathConverter(BaseConverter): - """Like the default :class:`UnicodeConverter`, but it also matches slashes. This is useful for wikis and similar applications:: @@ -1295,16 +1320,17 @@ class PathConverter(BaseConverter): :param map: the :class:`Map`. """ - regex = '[^/].*?' + + regex = "[^/].*?" weight = 200 class NumberConverter(BaseConverter): - """Baseclass for `IntegerConverter` and `FloatConverter`. :internal: """ + weight = 50 def __init__(self, map, fixed_digits=0, min=None, max=None, signed=False): @@ -1317,18 +1343,19 @@ class NumberConverter(BaseConverter): self.signed = signed def to_python(self, value): - if (self.fixed_digits and len(value) != self.fixed_digits): + if self.fixed_digits and len(value) != self.fixed_digits: raise ValidationError() value = self.num_convert(value) - if (self.min is not None and value < self.min) or \ - (self.max is not None and value > self.max): + if (self.min is not None and value < self.min) or ( + self.max is not None and value > self.max + ): raise ValidationError() return value def to_url(self, value): value = self.num_convert(value) if self.fixed_digits: - value = ('%%0%sd' % self.fixed_digits) % value + value = ("%%0%sd" % self.fixed_digits) % value return str(value) @property @@ -1337,7 +1364,6 @@ class NumberConverter(BaseConverter): class IntegerConverter(NumberConverter): - """This converter only accepts integer values:: Rule("/page/<int:page>") @@ -1358,12 +1384,12 @@ class IntegerConverter(NumberConverter): .. versionadded:: 0.15 The ``signed`` parameter. """ - regex = r'\d+' + + regex = r"\d+" num_convert = int class FloatConverter(NumberConverter): - """This converter only accepts floating point values:: Rule("/probability/<float:probability>") @@ -1381,7 +1407,8 @@ class FloatConverter(NumberConverter): .. versionadded:: 0.15 The ``signed`` parameter. """ - regex = r'\d+\.\d+' + + regex = r"\d+\.\d+" num_convert = float def __init__(self, map, min=None, max=None, signed=False): @@ -1389,7 +1416,6 @@ class FloatConverter(NumberConverter): class UUIDConverter(BaseConverter): - """This converter only accepts UUID strings:: Rule('/object/<uuid:identifier>') @@ -1398,8 +1424,11 @@ class UUIDConverter(BaseConverter): :param map: the :class:`Map`. """ - regex = r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' \ - r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}' + + regex = ( + r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" + r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" + ) def to_python(self, value): return uuid.UUID(value) @@ -1410,18 +1439,17 @@ class UUIDConverter(BaseConverter): #: the default converter mapping for the map. DEFAULT_CONVERTERS = { - 'default': UnicodeConverter, - 'string': UnicodeConverter, - 'any': AnyConverter, - 'path': PathConverter, - 'int': IntegerConverter, - 'float': FloatConverter, - 'uuid': UUIDConverter, + "default": UnicodeConverter, + "string": UnicodeConverter, + "any": AnyConverter, + "path": PathConverter, + "int": IntegerConverter, + "float": FloatConverter, + "uuid": UUIDConverter, } class Map(object): - """The map class stores all the URL rules and some configuration parameters. Some of the configuration values are only stored on the `Map` instance since those affect all rules, others are just defaults @@ -1458,10 +1486,19 @@ class Map(object): #: A dict of default converters to be used. default_converters = ImmutableDict(DEFAULT_CONVERTERS) - def __init__(self, rules=None, default_subdomain='', charset='utf-8', - strict_slashes=True, redirect_defaults=True, - converters=None, sort_parameters=False, sort_key=None, - encoding_errors='replace', host_matching=False): + def __init__( + self, + rules=None, + default_subdomain="", + charset="utf-8", + strict_slashes=True, + redirect_defaults=True, + converters=None, + sort_parameters=False, + sort_key=None, + encoding_errors="replace", + host_matching=False, + ): self._rules = [] self._rules_by_endpoint = {} self._remap = True @@ -1528,9 +1565,16 @@ class Map(object): self._rules_by_endpoint.setdefault(rule.endpoint, []).append(rule) self._remap = True - def bind(self, server_name, script_name=None, subdomain=None, - url_scheme='http', default_method='GET', path_info=None, - query_args=None): + def bind( + self, + server_name, + script_name=None, + subdomain=None, + url_scheme="http", + default_method="GET", + path_info=None, + query_args=None, + ): """Return a new :class:`MapAdapter` with the details specified to the call. Note that `script_name` will default to ``'/'`` if not further specified or `None`. The `server_name` at least is a requirement @@ -1559,20 +1603,27 @@ class Map(object): server_name = server_name.lower() if self.host_matching: if subdomain is not None: - raise RuntimeError('host matching enabled and a ' - 'subdomain was provided') + raise RuntimeError("host matching enabled and a subdomain was provided") elif subdomain is None: subdomain = self.default_subdomain if script_name is None: - script_name = '/' + script_name = "/" if path_info is None: - path_info = '/' + path_info = "/" try: server_name = _encode_idna(server_name) except UnicodeError: raise BadHost() - return MapAdapter(self, server_name, script_name, subdomain, - url_scheme, path_info, default_method, query_args) + return MapAdapter( + self, + server_name, + script_name, + subdomain, + url_scheme, + path_info, + default_method, + query_args, + ) def bind_to_environ(self, environ, server_name=None, subdomain=None): """Like :meth:`bind` but you can pass it an WSGI environment and it @@ -1618,8 +1669,8 @@ class Map(object): server_name = server_name.lower() if subdomain is None and not self.host_matching: - cur_server_name = wsgi_server_name.split('.') - real_server_name = server_name.split('.') + cur_server_name = wsgi_server_name.split(".") + real_server_name = server_name.split(".") offset = -len(real_server_name) if cur_server_name[offset:] != real_server_name: # This can happen even with valid configs if the server was @@ -1627,22 +1678,28 @@ class Map(object): # Instead of raising an exception like in Werkzeug 0.7 or # earlier we go by an invalid subdomain which will result # in a 404 error on matching. - subdomain = '<invalid>' + subdomain = "<invalid>" else: - subdomain = '.'.join(filter(None, cur_server_name[:offset])) + subdomain = ".".join(filter(None, cur_server_name[:offset])) def _get_wsgi_string(name): val = environ.get(name) if val is not None: return wsgi_decoding_dance(val, self.charset) - script_name = _get_wsgi_string('SCRIPT_NAME') - path_info = _get_wsgi_string('PATH_INFO') - query_args = _get_wsgi_string('QUERY_STRING') - return Map.bind(self, server_name, script_name, - subdomain, environ['wsgi.url_scheme'], - environ['REQUEST_METHOD'], path_info, - query_args=query_args) + script_name = _get_wsgi_string("SCRIPT_NAME") + path_info = _get_wsgi_string("PATH_INFO") + query_args = _get_wsgi_string("QUERY_STRING") + return Map.bind( + self, + server_name, + script_name, + subdomain, + environ["wsgi.url_scheme"], + environ["REQUEST_METHOD"], + path_info, + query_args=query_args, + ) def update(self): """Called before matching and building to keep the compiled rules @@ -1662,7 +1719,7 @@ class Map(object): def __repr__(self): rules = self.iter_rules() - return '%s(%s)' % (self.__class__.__name__, pformat(list(rules))) + return "%s(%s)" % (self.__class__.__name__, pformat(list(rules))) class MapAdapter(object): @@ -1671,13 +1728,22 @@ class MapAdapter(object): the URL matching and building based on runtime information. """ - def __init__(self, map, server_name, script_name, subdomain, - url_scheme, path_info, default_method, query_args=None): + def __init__( + self, + map, + server_name, + script_name, + subdomain, + url_scheme, + path_info, + default_method, + query_args=None, + ): self.map = map self.server_name = to_unicode(server_name) script_name = to_unicode(script_name) - if not script_name.endswith(u'/'): - script_name += u'/' + if not script_name.endswith(u"/"): + script_name += u"/" self.script_name = script_name self.subdomain = to_unicode(subdomain) self.url_scheme = to_unicode(url_scheme) @@ -1685,8 +1751,9 @@ class MapAdapter(object): self.default_method = to_unicode(default_method) self.query_args = query_args - def dispatch(self, view_func, path_info=None, method=None, - catch_http_exceptions=False): + def dispatch( + self, view_func, path_info=None, method=None, catch_http_exceptions=False + ): """Does the complete dispatching process. `view_func` is called with the endpoint and a dict with the values for the view. It should look up the view function, call it, and return a response object @@ -1740,8 +1807,7 @@ class MapAdapter(object): return e raise - def match(self, path_info=None, method=None, return_rule=False, - query_args=None): + def match(self, path_info=None, method=None, return_rule=False, query_args=None): """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). The following things can then happen: @@ -1827,9 +1893,9 @@ class MapAdapter(object): query_args = self.query_args method = (method or self.default_method).upper() - path = u'%s|%s' % ( + path = u"%s|%s" % ( self.map.host_matching and self.server_name or self.subdomain, - path_info and '/%s' % path_info.lstrip('/') + path_info and "/%s" % path_info.lstrip("/"), ) have_match_for = set() @@ -1837,12 +1903,18 @@ class MapAdapter(object): try: rv = rule.match(path, method) except RequestSlash: - raise RequestRedirect(self.make_redirect_url( - url_quote(path_info, self.map.charset, - safe='/:|+') + '/', query_args)) + raise RequestRedirect( + self.make_redirect_url( + url_quote(path_info, self.map.charset, safe="/:|+") + "/", + query_args, + ) + ) except RequestAliasRedirect as e: - raise RequestRedirect(self.make_alias_redirect_url( - path, rule.endpoint, e.matched_values, method, query_args)) + raise RequestRedirect( + self.make_alias_redirect_url( + path, rule.endpoint, e.matched_values, method, query_args + ) + ) if rv is None: continue if rule.methods is not None and method not in rule.methods: @@ -1850,26 +1922,34 @@ class MapAdapter(object): continue if self.map.redirect_defaults: - redirect_url = self.get_default_redirect(rule, method, rv, - query_args) + redirect_url = self.get_default_redirect(rule, method, rv, query_args) if redirect_url is not None: raise RequestRedirect(redirect_url) if rule.redirect_to is not None: if isinstance(rule.redirect_to, string_types): + def _handle_match(match): value = rv[match.group(1)] return rule._converters[match.group(1)].to_url(value) - redirect_url = _simple_rule_re.sub(_handle_match, - rule.redirect_to) + + redirect_url = _simple_rule_re.sub(_handle_match, rule.redirect_to) else: redirect_url = rule.redirect_to(self, **rv) - raise RequestRedirect(str(url_join('%s://%s%s%s' % ( - self.url_scheme or 'http', - self.subdomain + '.' if self.subdomain else '', - self.server_name, - self.script_name - ), redirect_url))) + raise RequestRedirect( + str( + url_join( + "%s://%s%s%s" + % ( + self.url_scheme or "http", + self.subdomain + "." if self.subdomain else "", + self.server_name, + self.script_name, + ), + redirect_url, + ) + ) + ) if return_rule: return rule, rv @@ -1903,7 +1983,7 @@ class MapAdapter(object): .. versionadded:: 0.7 """ try: - self.match(path_info, method='--') + self.match(path_info, method="--") except MethodNotAllowed as e: return e.valid_methods except HTTPException: @@ -1918,13 +1998,13 @@ class MapAdapter(object): if self.map.host_matching: if domain_part is None: return self.server_name - return to_unicode(domain_part, 'ascii') + return to_unicode(domain_part, "ascii") subdomain = domain_part if subdomain is None: subdomain = self.subdomain else: - subdomain = to_unicode(subdomain, 'ascii') - return (subdomain + u'.' if subdomain else u'') + self.server_name + subdomain = to_unicode(subdomain, "ascii") + return (subdomain + u"." if subdomain else u"") + self.server_name def get_default_redirect(self, rule, method, values, query_args): """A helper that returns the URL to redirect to if it finds one. @@ -1939,12 +2019,10 @@ class MapAdapter(object): # with the highest priority up for building. if r is rule: break - if r.provides_defaults_for(rule) and \ - r.suitable_for(values, method): + if r.provides_defaults_for(rule) and r.suitable_for(values, method): values.update(r.defaults) domain_part, path = r.build(values) - return self.make_redirect_url( - path, query_args, domain_part=domain_part) + return self.make_redirect_url(path, query_args, domain_part=domain_part) def encode_query_args(self, query_args): if not isinstance(query_args, string_types): @@ -1956,25 +2034,29 @@ class MapAdapter(object): :internal: """ - suffix = '' + suffix = "" if query_args: - suffix = '?' + self.encode_query_args(query_args) - return str('%s://%s/%s%s' % ( - self.url_scheme or 'http', - self.get_host(domain_part), - posixpath.join(self.script_name[:-1].lstrip('/'), - path_info.lstrip('/')), - suffix - )) + suffix = "?" + self.encode_query_args(query_args) + return str( + "%s://%s/%s%s" + % ( + self.url_scheme or "http", + self.get_host(domain_part), + posixpath.join( + self.script_name[:-1].lstrip("/"), path_info.lstrip("/") + ), + suffix, + ) + ) def make_alias_redirect_url(self, path, endpoint, values, method, query_args): """Internally called to make an alias redirect URL.""" - url = self.build(endpoint, values, method, append_unknown=False, - force_external=True) + url = self.build( + endpoint, values, method, append_unknown=False, force_external=True + ) if query_args: - url += '?' + self.encode_query_args(query_args) - assert url != path, 'detected invalid alias setting. No canonical ' \ - 'URL found' + url += "?" + self.encode_query_args(query_args) + assert url != path, "detected invalid alias setting. No canonical URL found" return url def _partial_build(self, endpoint, values, method, append_unknown): @@ -1985,8 +2067,9 @@ class MapAdapter(object): """ # in case the method is none, try with the default method first if method is None: - rv = self._partial_build(endpoint, values, self.default_method, - append_unknown) + rv = self._partial_build( + endpoint, values, self.default_method, append_unknown + ) if rv is not None: return rv @@ -1998,8 +2081,14 @@ class MapAdapter(object): if rv is not None: return rv - def build(self, endpoint, values=None, method=None, force_external=False, - append_unknown=True): + def build( + self, + endpoint, + values=None, + method=None, + force_external=False, + append_unknown=True, + ): """Building URLs works pretty much the other way round. Instead of `match` you call `build` and pass it the endpoint and a dict of arguments for the placeholders. @@ -2100,10 +2189,13 @@ class MapAdapter(object): (self.map.host_matching and host == self.server_name) or (not self.map.host_matching and domain_part == self.subdomain) ): - return '%s/%s' % (self.script_name.rstrip('/'), path.lstrip('/')) - return str('%s//%s%s/%s' % ( - self.url_scheme + ':' if self.url_scheme else '', - host, - self.script_name[:-1], - path.lstrip('/') - )) + return "%s/%s" % (self.script_name.rstrip("/"), path.lstrip("/")) + return str( + "%s//%s%s/%s" + % ( + self.url_scheme + ":" if self.url_scheme else "", + host, + self.script_name[:-1], + path.lstrip("/"), + ) + ) diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 6d2b8de0..1842afd0 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -8,31 +8,35 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import os -import hmac +import codecs import hashlib +import hmac +import os import posixpath -import codecs -from struct import Struct from random import SystemRandom +from struct import Struct -from werkzeug._compat import range_type, PY2, text_type, izip, to_bytes, \ - to_native - +from ._compat import izip +from ._compat import PY2 +from ._compat import range_type +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_native -SALT_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' +SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" DEFAULT_PBKDF2_ITERATIONS = 150000 - -_pack_int = Struct('>I').pack -_builtin_safe_str_cmp = getattr(hmac, 'compare_digest', None) +_pack_int = Struct(">I").pack +_builtin_safe_str_cmp = getattr(hmac, "compare_digest", None) _sys_rng = SystemRandom() -_os_alt_seps = list(sep for sep in [os.path.sep, os.path.altsep] - if sep not in (None, '/')) +_os_alt_seps = list( + sep for sep in [os.path.sep, os.path.altsep] if sep not in (None, "/") +) -def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, - keylen=None, hashfunc=None): +def pbkdf2_hex( + data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None +): """Like :func:`pbkdf2_bin`, but returns a hex-encoded string. .. versionadded:: 0.9 @@ -47,11 +51,12 @@ def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, from the hashlib module. Defaults to sha256. """ rv = pbkdf2_bin(data, salt, iterations, keylen, hashfunc) - return to_native(codecs.encode(rv, 'hex_codec')) + return to_native(codecs.encode(rv, "hex_codec")) -def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, - keylen=None, hashfunc=None): +def pbkdf2_bin( + data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None +): """Returns a binary digest for the PBKDF2 hash algorithm of `data` with the given `salt`. It iterates `iterations` times and produces a key of `keylen` bytes. By default, SHA-256 is used as hash function; @@ -69,14 +74,14 @@ def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, from the hashlib module. Defaults to sha256. """ if not hashfunc: - hashfunc = 'sha256' + hashfunc = "sha256" data = to_bytes(data) salt = to_bytes(salt) if callable(hashfunc): _test_hash = hashfunc() - hash_name = getattr(_test_hash, 'name', None) + hash_name = getattr(_test_hash, "name", None) else: hash_name = hashfunc return hashlib.pbkdf2_hmac(hash_name, data, salt, iterations, keylen) @@ -91,9 +96,9 @@ def safe_str_cmp(a, b): .. versionadded:: 0.7 """ if isinstance(a, text_type): - a = a.encode('utf-8') + a = a.encode("utf-8") if isinstance(b, text_type): - b = b.encode('utf-8') + b = b.encode("utf-8") if _builtin_safe_str_cmp is not None: return _builtin_safe_str_cmp(a, b) @@ -115,8 +120,8 @@ def safe_str_cmp(a, b): def gen_salt(length): """Generate a random string of SALT_CHARS with specified ``length``.""" if length <= 0: - raise ValueError('Salt length must be positive') - return ''.join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length)) + raise ValueError("Salt length must be positive") + return "".join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length)) def _hash_internal(method, salt, password): @@ -124,31 +129,31 @@ def _hash_internal(method, salt, password): unsalted and salted passwords. In case salted passwords are used hmac is used. """ - if method == 'plain': + if method == "plain": return password, method if isinstance(password, text_type): - password = password.encode('utf-8') + password = password.encode("utf-8") - if method.startswith('pbkdf2:'): - args = method[7:].split(':') + if method.startswith("pbkdf2:"): + args = method[7:].split(":") if len(args) not in (1, 2): - raise ValueError('Invalid number of arguments for PBKDF2') + raise ValueError("Invalid number of arguments for PBKDF2") method = args.pop(0) iterations = args and int(args[0] or 0) or DEFAULT_PBKDF2_ITERATIONS is_pbkdf2 = True - actual_method = 'pbkdf2:%s:%d' % (method, iterations) + actual_method = "pbkdf2:%s:%d" % (method, iterations) else: is_pbkdf2 = False actual_method = method if is_pbkdf2: if not salt: - raise ValueError('Salt is required for PBKDF2') + raise ValueError("Salt is required for PBKDF2") rv = pbkdf2_hex(password, salt, iterations, hashfunc=method) elif salt: if isinstance(salt, text_type): - salt = salt.encode('utf-8') + salt = salt.encode("utf-8") mac = _create_mac(salt, password, method) rv = mac.hexdigest() else: @@ -159,14 +164,17 @@ def _hash_internal(method, salt, password): def _create_mac(key, msg, method): if callable(method): return hmac.HMAC(key, msg, method) - hashfunc = lambda d=b'': hashlib.new(method, d) + + def hashfunc(d=b""): + return hashlib.new(method, d) + # Python 2.7 used ``hasattr(digestmod, '__call__')`` # to detect if hashfunc is callable hashfunc.__call__ = hashfunc return hmac.HMAC(key, msg, hashfunc) -def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8): +def generate_password_hash(password, method="pbkdf2:sha256", salt_length=8): """Hash a password with the given method and salt with a string of the given length. The format of the string returned includes the method that was used so that :func:`check_password_hash` can check the hash. @@ -191,9 +199,9 @@ def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8): to enable PBKDF2. :param salt_length: the length of the salt in letters. """ - salt = gen_salt(salt_length) if method != 'plain' else '' + salt = gen_salt(salt_length) if method != "plain" else "" h, actual_method = _hash_internal(method, salt, password) - return '%s$%s$%s' % (actual_method, salt, h) + return "%s$%s$%s" % (actual_method, salt, h) def check_password_hash(pwhash, password): @@ -207,9 +215,9 @@ def check_password_hash(pwhash, password): :func:`generate_password_hash`. :param password: the plaintext password to compare against the hash. """ - if pwhash.count('$') < 2: + if pwhash.count("$") < 2: return False - method, salt, hashval = pwhash.split('$', 2) + method, salt, hashval = pwhash.split("$", 2) return safe_str_cmp(_hash_internal(method, salt, password)[0], hashval) @@ -222,14 +230,12 @@ def safe_join(directory, *pathnames): """ parts = [directory] for filename in pathnames: - if filename != '': + if filename != "": filename = posixpath.normpath(filename) for sep in _os_alt_seps: if sep in filename: return None - if os.path.isabs(filename) or \ - filename == '..' or \ - filename.startswith('../'): + if os.path.isabs(filename) or filename == ".." or filename.startswith("../"): return None parts.append(filename) return posixpath.join(*parts) diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index 363bb029..4a179b3e 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -37,72 +37,74 @@ """ import io import os +import signal import socket import sys -import signal - - -can_fork = hasattr(os, "fork") +import werkzeug +from ._compat import PY2 +from ._compat import reraise +from ._compat import WIN +from ._compat import wsgi_encoding_dance +from ._internal import _log +from .exceptions import InternalServerError +from .urls import uri_to_iri +from .urls import url_parse +from .urls import url_unquote try: - import termcolor + import socketserver + from http.server import BaseHTTPRequestHandler + from http.server import HTTPServer except ImportError: - termcolor = None + import SocketServer as socketserver + from BaseHTTPServer import HTTPServer + from BaseHTTPServer import BaseHTTPRequestHandler try: import ssl except ImportError: + class _SslDummy(object): def __getattr__(self, name): - raise RuntimeError('SSL support unavailable') + raise RuntimeError("SSL support unavailable") + ssl = _SslDummy() +try: + import termcolor +except ImportError: + termcolor = None + def _get_openssl_crypto_module(): try: from OpenSSL import crypto except ImportError: - raise TypeError('Using ad-hoc certificates requires the pyOpenSSL ' - 'library.') + raise TypeError("Using ad-hoc certificates requires the pyOpenSSL library.") else: return crypto -try: - import SocketServer as socketserver - from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -except ImportError: - import socketserver - from http.server import HTTPServer, BaseHTTPRequestHandler - ThreadingMixIn = socketserver.ThreadingMixIn +can_fork = hasattr(os, "fork") if can_fork: ForkingMixIn = socketserver.ForkingMixIn else: + class ForkingMixIn(object): pass + try: af_unix = socket.AF_UNIX except AttributeError: af_unix = None -import werkzeug -from werkzeug._internal import _log -from werkzeug._compat import PY2, WIN, reraise, wsgi_encoding_dance -from werkzeug.urls import ( - url_parse, - url_unquote, - uri_to_iri, -) -from werkzeug.exceptions import InternalServerError - - LISTEN_QUEUE = 128 -can_open_by_fd = not WIN and hasattr(socket, 'fromfd') +can_open_by_fd = not WIN and hasattr(socket, "fromfd") # On Python 3, ConnectionError represents the same errnos as # socket.error from Python 2, while socket.error is an alias for the @@ -126,12 +128,12 @@ class DechunkedInput(io.RawIOBase): def read_chunk_len(self): try: - line = self._rfile.readline().decode('latin1') + line = self._rfile.readline().decode("latin1") _len = int(line.strip(), 16) except ValueError: - raise IOError('Invalid chunk header') + raise IOError("Invalid chunk header") if _len < 0: - raise IOError('Negative chunk length not allowed') + raise IOError("Negative chunk length not allowed") return _len def readinto(self, buf): @@ -152,7 +154,7 @@ class DechunkedInput(io.RawIOBase): # buffer. If this operation fully consumes the chunk, this will # reset self._len to 0. n = min(len(buf), self._len) - buf[read:read + n] = self._rfile.read(n) + buf[read : read + n] = self._rfile.read(n) self._len -= n read += n @@ -160,8 +162,8 @@ class DechunkedInput(io.RawIOBase): # Skip the terminating newline of a chunk that has been fully # consumed. This also applies to the 0-sized final chunk terminator = self._rfile.readline() - if terminator not in (b'\n', b'\r\n', b'\r'): - raise IOError('Missing chunk terminating newline') + if terminator not in (b"\n", b"\r\n", b"\r"): + raise IOError("Missing chunk terminating newline") return read @@ -172,7 +174,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): @property def server_version(self): - return 'Werkzeug/' + werkzeug.__version__ + return "Werkzeug/" + werkzeug.__version__ def make_environ(self): request_url = url_parse(self.path) @@ -180,9 +182,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): def shutdown_server(): self.server.shutdown_signal = True - url_scheme = 'http' if self.server.ssl_context is None else 'https' + url_scheme = "http" if self.server.ssl_context is None else "https" if not self.client_address: - self.client_address = '<local>' + self.client_address = "<local>" if isinstance(self.client_address, str): self.client_address = (self.client_address, 0) else: @@ -190,57 +192,57 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): path_info = url_unquote(request_url.path) environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': url_scheme, - 'wsgi.input': self.rfile, - 'wsgi.errors': sys.stderr, - 'wsgi.multithread': self.server.multithread, - 'wsgi.multiprocess': self.server.multiprocess, - 'wsgi.run_once': False, - 'werkzeug.server.shutdown': shutdown_server, - 'SERVER_SOFTWARE': self.server_version, - 'REQUEST_METHOD': self.command, - 'SCRIPT_NAME': '', - 'PATH_INFO': wsgi_encoding_dance(path_info), - 'QUERY_STRING': wsgi_encoding_dance(request_url.query), + "wsgi.version": (1, 0), + "wsgi.url_scheme": url_scheme, + "wsgi.input": self.rfile, + "wsgi.errors": sys.stderr, + "wsgi.multithread": self.server.multithread, + "wsgi.multiprocess": self.server.multiprocess, + "wsgi.run_once": False, + "werkzeug.server.shutdown": shutdown_server, + "SERVER_SOFTWARE": self.server_version, + "REQUEST_METHOD": self.command, + "SCRIPT_NAME": "", + "PATH_INFO": wsgi_encoding_dance(path_info), + "QUERY_STRING": wsgi_encoding_dance(request_url.query), # Non-standard, added by mod_wsgi, uWSGI "REQUEST_URI": wsgi_encoding_dance(self.path), # Non-standard, added by gunicorn "RAW_URI": wsgi_encoding_dance(self.path), - 'REMOTE_ADDR': self.address_string(), - 'REMOTE_PORT': self.port_integer(), - 'SERVER_NAME': self.server.server_address[0], - 'SERVER_PORT': str(self.server.server_address[1]), - 'SERVER_PROTOCOL': self.request_version + "REMOTE_ADDR": self.address_string(), + "REMOTE_PORT": self.port_integer(), + "SERVER_NAME": self.server.server_address[0], + "SERVER_PORT": str(self.server.server_address[1]), + "SERVER_PROTOCOL": self.request_version, } for key, value in self.get_header_items(): - key = key.upper().replace('-', '_') - if key not in ('CONTENT_TYPE', 'CONTENT_LENGTH'): - key = 'HTTP_' + key + key = key.upper().replace("-", "_") + if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): + key = "HTTP_" + key if key in environ: value = "{},{}".format(environ[key], value) environ[key] = value - if environ.get('HTTP_TRANSFER_ENCODING', '').strip().lower() == 'chunked': - environ['wsgi.input_terminated'] = True - environ['wsgi.input'] = DechunkedInput(environ['wsgi.input']) + if environ.get("HTTP_TRANSFER_ENCODING", "").strip().lower() == "chunked": + environ["wsgi.input_terminated"] = True + environ["wsgi.input"] = DechunkedInput(environ["wsgi.input"]) if request_url.scheme and request_url.netloc: - environ['HTTP_HOST'] = request_url.netloc + environ["HTTP_HOST"] = request_url.netloc return environ def run_wsgi(self): - if self.headers.get('Expect', '').lower().strip() == '100-continue': - self.wfile.write(b'HTTP/1.1 100 Continue\r\n\r\n') + if self.headers.get("Expect", "").lower().strip() == "100-continue": + self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.environ = environ = self.make_environ() headers_set = [] headers_sent = [] def write(data): - assert headers_set, 'write() before start_response' + assert headers_set, "write() before start_response" if not headers_sent: status, response_headers = headers_sent[:] = headers_set try: @@ -254,18 +256,21 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): self.send_header(key, value) key = key.lower() header_keys.add(key) - if not ('content-length' in header_keys - or environ['REQUEST_METHOD'] == 'HEAD' - or code < 200 or code in (204, 304)): + if not ( + "content-length" in header_keys + or environ["REQUEST_METHOD"] == "HEAD" + or code < 200 + or code in (204, 304) + ): self.close_connection = True - self.send_header('Connection', 'close') - if 'server' not in header_keys: - self.send_header('Server', self.version_string()) - if 'date' not in header_keys: - self.send_header('Date', self.date_time_string()) + self.send_header("Connection", "close") + if "server" not in header_keys: + self.send_header("Server", self.version_string()) + if "date" not in header_keys: + self.send_header("Date", self.date_time_string()) self.end_headers() - assert isinstance(data, bytes), 'applications must write bytes' + assert isinstance(data, bytes), "applications must write bytes" self.wfile.write(data) self.wfile.flush() @@ -277,7 +282,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): finally: exc_info = None elif headers_set: - raise AssertionError('Headers already set') + raise AssertionError("Headers already set") headers_set[:] = [status, response_headers] return write @@ -287,9 +292,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): for data in application_iter: write(data) if not headers_sent: - write(b'') + write(b"") finally: - if hasattr(application_iter, 'close'): + if hasattr(application_iter, "close"): application_iter.close() application_iter = None @@ -300,7 +305,8 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): except Exception: if self.server.passthrough_errors: raise - from werkzeug.debug.tbtools import get_current_traceback + from .debug.tbtools import get_current_traceback + traceback = get_current_traceback(ignore_system_exceptions=True) try: # if we haven't yet sent the headers but they are set @@ -310,8 +316,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): execute(InternalServerError()) except Exception: pass - self.server.log('error', 'Error on request:\n%s', - traceback.plaintext) + self.server.log("error", "Error on request:\n%s", traceback.plaintext) def handle(self): """Handles a request ignoring dropped connections.""" @@ -332,7 +337,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): later. It's the best we can do. """ # Windows does not provide SIGKILL, go with SIGTERM then. - sig = getattr(signal, 'SIGKILL', signal.SIGTERM) + sig = getattr(signal, "SIGKILL", signal.SIGTERM) # reloader active if is_running_from_reloader(): os.kill(os.getpid(), sig) @@ -358,19 +363,19 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): """Send the response header and log the response code.""" self.log_request(code) if message is None: - message = code in self.responses and self.responses[code][0] or '' - if self.request_version != 'HTTP/0.9': + message = code in self.responses and self.responses[code][0] or "" + if self.request_version != "HTTP/0.9": hdr = "%s %d %s\r\n" % (self.protocol_version, code, message) - self.wfile.write(hdr.encode('ascii')) + self.wfile.write(hdr.encode("ascii")) def version_string(self): return BaseHTTPRequestHandler.version_string(self).strip() def address_string(self): - if getattr(self, 'environ', None): - return self.environ['REMOTE_ADDR'] + if getattr(self, "environ", None): + return self.environ["REMOTE_ADDR"] elif not self.client_address: - return '<local>' + return "<local>" elif isinstance(self.client_address, str): return self.client_address else: @@ -379,7 +384,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): def port_integer(self): return self.client_address[1] - def log_request(self, code='-', size='-'): + def log_request(self, code="-", size="-"): try: path = uri_to_iri(self.path) msg = "%s %s %s" % (self.command, path, self.request_version) @@ -392,33 +397,35 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): if termcolor: color = termcolor.colored - if code[0] == '1': # 1xx - Informational - msg = color(msg, attrs=['bold']) - elif code[0] == '2': # 2xx - Success - msg = color(msg, color='white') - elif code == '304': # 304 - Resource Not Modified - msg = color(msg, color='cyan') - elif code[0] == '3': # 3xx - Redirection - msg = color(msg, color='green') - elif code == '404': # 404 - Resource Not Found - msg = color(msg, color='yellow') - elif code[0] == '4': # 4xx - Client Error - msg = color(msg, color='red', attrs=['bold']) - else: # 5xx, or any other response - msg = color(msg, color='magenta', attrs=['bold']) - - self.log('info', '"%s" %s %s', msg, code, size) + if code[0] == "1": # 1xx - Informational + msg = color(msg, attrs=["bold"]) + elif code[0] == "2": # 2xx - Success + msg = color(msg, color="white") + elif code == "304": # 304 - Resource Not Modified + msg = color(msg, color="cyan") + elif code[0] == "3": # 3xx - Redirection + msg = color(msg, color="green") + elif code == "404": # 404 - Resource Not Found + msg = color(msg, color="yellow") + elif code[0] == "4": # 4xx - Client Error + msg = color(msg, color="red", attrs=["bold"]) + else: # 5xx, or any other response + msg = color(msg, color="magenta", attrs=["bold"]) + + self.log("info", '"%s" %s %s', msg, code, size) def log_error(self, *args): - self.log('error', *args) + self.log("error", *args) def log_message(self, format, *args): - self.log('info', format, *args) + self.log("info", format, *args) def log(self, type, message, *args): - _log(type, '%s - - [%s] %s\n' % (self.address_string(), - self.log_date_time_string(), - message % args)) + _log( + type, + "%s - - [%s] %s\n" + % (self.address_string(), self.log_date_time_string(), message % args), + ) def get_header_items(self): """ @@ -434,15 +441,18 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): :return: List of tuples containing header hey/value pairs """ if PY2: - # For Python 2, process the headers manually according to W3C RFC 2616 Section 4.2 + # For Python 2, process the headers manually according to + # W3C RFC 2616 Section 4.2. items = [] for header in self.headers.headers: - # Remove the \n\r from the header and split on the : to get the field name and value + # Remove "\n\r" from the header and split on ":" to get + # the field name and value. key, value = header[0:-2].split(":", 1) - # Add the key and the value once stripped of leading white space. The specification - # allows for stripping trailing white space but the Python 3 code does not strip - # trailing white space. Therefore, trailing space will be left as is to match the - # Python 3 behavior + # Add the key and the value once stripped of leading + # white space. The specification allows for stripping + # trailing white space but the Python 3 code does not + # strip trailing white space. Therefore, trailing space + # will be left as is to match the Python 3 behavior. items.append((key, value.lstrip())) else: items = self.headers.items() @@ -456,11 +466,12 @@ BaseRequestHandler = WSGIRequestHandler def generate_adhoc_ssl_pair(cn=None): from random import random + crypto = _get_openssl_crypto_module() # pretty damn sure that this is not actually accepted by anyone if cn is None: - cn = '*' + cn = "*" cert = crypto.X509() cert.set_serial_number(int(random() * sys.maxsize)) @@ -469,7 +480,7 @@ def generate_adhoc_ssl_pair(cn=None): subject = cert.get_subject() subject.CN = cn - subject.O = 'Dummy Certificate' # noqa: E741 + subject.O = "Dummy Certificate" # noqa: E741 issuer = cert.get_issuer() issuer.CN = subject.CN @@ -478,7 +489,7 @@ def generate_adhoc_ssl_pair(cn=None): pkey = crypto.PKey() pkey.generate_key(crypto.TYPE_RSA, 2048) cert.set_pubkey(pkey) - cert.sign(pkey, 'sha256') + cert.sign(pkey, "sha256") return cert, pkey @@ -502,16 +513,17 @@ def make_ssl_devcert(base_path, host=None, cn=None): :param cn: the `CN` to use. """ from OpenSSL import crypto + if host is not None: - cn = '*.%s/CN=%s' % (host, host) + cn = "*.%s/CN=%s" % (host, host) cert, pkey = generate_adhoc_ssl_pair(cn=cn) - cert_file = base_path + '.crt' - pkey_file = base_path + '.key' + cert_file = base_path + ".crt" + pkey_file = base_path + ".key" - with open(cert_file, 'wb') as f: + with open(cert_file, "wb") as f: f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) - with open(pkey_file, 'wb') as f: + with open(pkey_file, "wb") as f: f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)) return cert_file, pkey_file @@ -557,8 +569,8 @@ def load_ssl_context(cert_file, pkey_file=None, protocol=None): class _SSLContext(object): - '''A dummy class with a small subset of Python3's ``ssl.SSLContext``, only - intended to be used with and by Werkzeug.''' + """A dummy class with a small subset of Python3's ``ssl.SSLContext``, only + intended to be used with and by Werkzeug.""" def __init__(self, protocol): self._protocol = protocol @@ -572,9 +584,13 @@ class _SSLContext(object): self._password = password def wrap_socket(self, sock, **kwargs): - return ssl.wrap_socket(sock, keyfile=self._keyfile, - certfile=self._certfile, - ssl_version=self._protocol, **kwargs) + return ssl.wrap_socket( + sock, + keyfile=self._keyfile, + certfile=self._certfile, + ssl_version=self._protocol, + **kwargs + ) def is_ssl_error(error=None): @@ -582,6 +598,7 @@ def is_ssl_error(error=None): exc_types = (ssl.SSLError,) try: from OpenSSL.SSL import Error + exc_types += (Error,) except ImportError: pass @@ -606,9 +623,9 @@ def select_address_family(host, port): # return info[0][0] # except socket.gaierror: # pass - if host.startswith('unix://'): + if host.startswith("unix://"): return socket.AF_UNIX - elif ':' in host and hasattr(socket, 'AF_INET6'): + elif ":" in host and hasattr(socket, "AF_INET6"): return socket.AF_INET6 return socket.AF_INET @@ -617,10 +634,11 @@ def get_sockaddr(host, port, family): """Return a fully qualified socket address that can be passed to :func:`socket.bind`.""" if family == af_unix: - return host.split('://', 1)[1] + return host.split("://", 1)[1] try: res = socket.getaddrinfo( - host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP + ) except socket.gaierror: return host, port return res[0][4] @@ -629,13 +647,20 @@ def get_sockaddr(host, port, family): class BaseWSGIServer(HTTPServer, object): """Simple single-threaded, single-process WSGI server.""" + multithread = False multiprocess = False request_queue_size = LISTEN_QUEUE def __init__( - self, host, port, app, handler=None, passthrough_errors=False, - ssl_context=None, fd=None + self, + host, + port, + app, + handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, ): if handler is None: handler = WSGIRequestHandler @@ -643,17 +668,13 @@ class BaseWSGIServer(HTTPServer, object): self.address_family = select_address_family(host, port) if fd is not None: - real_sock = socket.fromfd( - fd, self.address_family, socket.SOCK_STREAM) + real_sock = socket.fromfd(fd, self.address_family, socket.SOCK_STREAM) port = 0 server_address = get_sockaddr(host, int(port), self.address_family) # remove socket file if it already exists - if ( - self.address_family == af_unix - and os.path.exists(server_address) - ): + if self.address_family == af_unix and os.path.exists(server_address): os.unlink(server_address) HTTPServer.__init__(self, server_address, handler) @@ -672,7 +693,7 @@ class BaseWSGIServer(HTTPServer, object): if ssl_context is not None: if isinstance(ssl_context, tuple): ssl_context = load_ssl_context(*ssl_context) - if ssl_context == 'adhoc': + if ssl_context == "adhoc": ssl_context = generate_adhoc_ssl_context() # If we are on Python 2 the return value from socket.fromfd # is an internal socket object but what we need for ssl wrap @@ -714,6 +735,7 @@ class BaseWSGIServer(HTTPServer, object): class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer): """A WSGI server that does threading.""" + multithread = True daemon_threads = True @@ -721,35 +743,63 @@ class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer): class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer): """A WSGI server that does forking.""" + multiprocess = True - def __init__(self, host, port, app, processes=40, handler=None, - passthrough_errors=False, ssl_context=None, fd=None): + def __init__( + self, + host, + port, + app, + processes=40, + handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, + ): if not can_fork: - raise ValueError('Your platform does not support forking.') - BaseWSGIServer.__init__(self, host, port, app, handler, - passthrough_errors, ssl_context, fd) + raise ValueError("Your platform does not support forking.") + BaseWSGIServer.__init__( + self, host, port, app, handler, passthrough_errors, ssl_context, fd + ) self.max_children = processes -def make_server(host=None, port=None, app=None, threaded=False, processes=1, - request_handler=None, passthrough_errors=False, - ssl_context=None, fd=None): +def make_server( + host=None, + port=None, + app=None, + threaded=False, + processes=1, + request_handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, +): """Create a new server instance that is either threaded, or forks or just processes one request after another. """ if threaded and processes > 1: - raise ValueError("cannot have a multithreaded and " - "multi process server.") + raise ValueError("cannot have a multithreaded and multi process server.") elif threaded: - return ThreadedWSGIServer(host, port, app, request_handler, - passthrough_errors, ssl_context, fd=fd) + return ThreadedWSGIServer( + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd + ) elif processes > 1: - return ForkingWSGIServer(host, port, app, processes, request_handler, - passthrough_errors, ssl_context, fd=fd) + return ForkingWSGIServer( + host, + port, + app, + processes, + request_handler, + passthrough_errors, + ssl_context, + fd=fd, + ) else: - return BaseWSGIServer(host, port, app, request_handler, - passthrough_errors, ssl_context, fd=fd) + return BaseWSGIServer( + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd + ) def is_running_from_reloader(): @@ -758,15 +808,26 @@ def is_running_from_reloader(): .. versionadded:: 0.10 """ - return os.environ.get('WERKZEUG_RUN_MAIN') == 'true' - - -def run_simple(hostname, port, application, use_reloader=False, - use_debugger=False, use_evalex=True, - extra_files=None, reloader_interval=1, - reloader_type='auto', threaded=False, - processes=1, request_handler=None, static_files=None, - passthrough_errors=False, ssl_context=None): + return os.environ.get("WERKZEUG_RUN_MAIN") == "true" + + +def run_simple( + hostname, + port, + application, + use_reloader=False, + use_debugger=False, + use_evalex=True, + extra_files=None, + reloader_interval=1, + reloader_type="auto", + threaded=False, + processes=1, + request_handler=None, + static_files=None, + passthrough_errors=False, + ssl_context=None, +): """Start a WSGI application. Optional features include a reloader, multithreading and fork support. @@ -837,36 +898,50 @@ def run_simple(hostname, port, application, use_reloader=False, to disable SSL (which is the default). """ if not isinstance(port, int): - raise TypeError('port must be an integer') + raise TypeError("port must be an integer") if use_debugger: - from werkzeug.debug import DebuggedApplication + from .debug import DebuggedApplication + application = DebuggedApplication(application, use_evalex) if static_files: - from werkzeug.wsgi import SharedDataMiddleware + from .middleware.shared_data import SharedDataMiddleware + application = SharedDataMiddleware(application, static_files) def log_startup(sock): - display_hostname = hostname if hostname not in ('', '*') else 'localhost' - quit_msg = '(Press CTRL+C to quit)' + display_hostname = hostname if hostname not in ("", "*") else "localhost" + quit_msg = "(Press CTRL+C to quit)" if sock.family == af_unix: - _log('info', ' * Running on %s %s', display_hostname, quit_msg) + _log("info", " * Running on %s %s", display_hostname, quit_msg) else: - if ':' in display_hostname: - display_hostname = '[%s]' % display_hostname + if ":" in display_hostname: + display_hostname = "[%s]" % display_hostname port = sock.getsockname()[1] - _log('info', ' * Running on %s://%s:%d/ %s', - 'http' if ssl_context is None else 'https', - display_hostname, port, quit_msg) + _log( + "info", + " * Running on %s://%s:%d/ %s", + "http" if ssl_context is None else "https", + display_hostname, + port, + quit_msg, + ) def inner(): try: - fd = int(os.environ['WERKZEUG_SERVER_FD']) + fd = int(os.environ["WERKZEUG_SERVER_FD"]) except (LookupError, ValueError): fd = None - srv = make_server(hostname, port, application, threaded, - processes, request_handler, - passthrough_errors, ssl_context, - fd=fd) + srv = make_server( + hostname, + port, + application, + threaded, + processes, + request_handler, + passthrough_errors, + ssl_context, + fd=fd, + ) if fd is None: log_startup(srv.socket) srv.serve_forever() @@ -877,9 +952,11 @@ def run_simple(hostname, port, application, use_reloader=False, # port is actually available. if not is_running_from_reloader(): if port == 0 and not can_open_by_fd: - raise ValueError('Cannot bind to a random port with enabled ' - 'reloader if the Python interpreter does ' - 'not support socket opening by fd.') + raise ValueError( + "Cannot bind to a random port with enabled " + "reloader if the Python interpreter does " + "not support socket opening by fd." + ) # Create and destroy a socket so that any exceptions are # raised before we spawn a separate Python interpreter and @@ -889,26 +966,26 @@ def run_simple(hostname, port, application, use_reloader=False, s = socket.socket(address_family, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(server_address) - if hasattr(s, 'set_inheritable'): + if hasattr(s, "set_inheritable"): s.set_inheritable(True) # If we can open the socket by file descriptor, then we can just # reuse this one and our socket will survive the restarts. if can_open_by_fd: - os.environ['WERKZEUG_SERVER_FD'] = str(s.fileno()) + os.environ["WERKZEUG_SERVER_FD"] = str(s.fileno()) s.listen(LISTEN_QUEUE) log_startup(s) else: s.close() if address_family == af_unix: - _log('info', "Unlinking %s" % server_address) + _log("info", "Unlinking %s" % server_address) os.unlink(server_address) # Do not use relative imports, otherwise "python -m werkzeug.serving" # breaks. - from werkzeug._reloader import run_with_reloader - run_with_reloader(inner, extra_files, reloader_interval, - reloader_type) + from ._reloader import run_with_reloader + + run_with_reloader(inner, extra_files, reloader_interval, reloader_type) else: inner() @@ -916,46 +993,63 @@ def run_simple(hostname, port, application, use_reloader=False, def run_with_reloader(*args, **kwargs): # People keep using undocumented APIs. Do not use this function # please, we do not guarantee that it continues working. - from werkzeug._reloader import run_with_reloader + from ._reloader import run_with_reloader + return run_with_reloader(*args, **kwargs) def main(): - '''A simple command-line interface for :py:func:`run_simple`.''' + """A simple command-line interface for :py:func:`run_simple`.""" # in contrast to argparse, this works at least under Python < 2.7 import optparse - from werkzeug.utils import import_string - - parser = optparse.OptionParser( - usage='Usage: %prog [options] app_module:app_object') - parser.add_option('-b', '--bind', dest='address', - help='The hostname:port the app should listen on.') - parser.add_option('-d', '--debug', dest='use_debugger', - action='store_true', default=False, - help='Use Werkzeug\'s debugger.') - parser.add_option('-r', '--reload', dest='use_reloader', - action='store_true', default=False, - help='Reload Python process if modules change.') + from .utils import import_string + + parser = optparse.OptionParser(usage="Usage: %prog [options] app_module:app_object") + parser.add_option( + "-b", + "--bind", + dest="address", + help="The hostname:port the app should listen on.", + ) + parser.add_option( + "-d", + "--debug", + dest="use_debugger", + action="store_true", + default=False, + help="Use Werkzeug's debugger.", + ) + parser.add_option( + "-r", + "--reload", + dest="use_reloader", + action="store_true", + default=False, + help="Reload Python process if modules change.", + ) options, args = parser.parse_args() hostname, port = None, None if options.address: - address = options.address.split(':') + address = options.address.split(":") hostname = address[0] if len(address) > 1: port = address[1] if len(args) != 1: - sys.stdout.write('No application supplied, or too much. See --help\n') + sys.stdout.write("No application supplied, or too much. See --help\n") sys.exit(1) app = import_string(args[0]) run_simple( - hostname=(hostname or '127.0.0.1'), port=int(port or 5000), - application=app, use_reloader=options.use_reloader, - use_debugger=options.use_debugger + hostname=(hostname or "127.0.0.1"), + port=int(port or 5000), + application=app, + use_reloader=options.use_reloader, + use_debugger=options.use_debugger, ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index 8f04ec21..9aeb4127 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -8,56 +8,69 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys import mimetypes -from time import time - -from random import random +import sys +from io import BytesIO from itertools import chain +from random import random from tempfile import TemporaryFile -from io import BytesIO +from time import time + +from ._compat import iteritems +from ._compat import iterlists +from ._compat import itervalues +from ._compat import make_literal_wrapper +from ._compat import reraise +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import wsgi_encoding_dance +from ._internal import _get_environ +from .datastructures import CallbackDict +from .datastructures import CombinedMultiDict +from .datastructures import EnvironHeaders +from .datastructures import FileMultiDict +from .datastructures import FileStorage +from .datastructures import Headers +from .datastructures import MultiDict +from .http import dump_cookie +from .http import dump_options_header +from .http import parse_options_header +from .urls import iri_to_uri +from .urls import url_encode +from .urls import url_fix +from .urls import url_parse +from .urls import url_unparse +from .urls import url_unquote +from .utils import get_content_type +from .wrappers import BaseRequest +from .wsgi import ClosingIterator +from .wsgi import get_current_url try: - from urllib2 import Request as U2Request -except ImportError: from urllib.request import Request as U2Request +except ImportError: + from urllib2 import Request as U2Request + try: from http.cookiejar import CookieJar -except ImportError: # Py2 +except ImportError: from cookielib import CookieJar -from werkzeug._compat import iterlists, iteritems, itervalues, to_bytes, \ - string_types, text_type, reraise, wsgi_encoding_dance, \ - make_literal_wrapper -from werkzeug._internal import _get_environ -from werkzeug.wrappers import BaseRequest -from werkzeug.urls import url_encode, url_fix, iri_to_uri, url_unquote, \ - url_unparse, url_parse -from werkzeug.wsgi import get_current_url, ClosingIterator -from werkzeug.utils import dump_cookie, get_content_type -from werkzeug.datastructures import ( - FileMultiDict, - MultiDict, - CombinedMultiDict, - Headers, - FileStorage, - CallbackDict, - EnvironHeaders, -) -from werkzeug.http import dump_options_header, parse_options_header - - -def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, - boundary=None, charset='utf-8'): + +def stream_encode_multipart( + values, use_tempfile=True, threshold=1024 * 500, boundary=None, charset="utf-8" +): """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. """ if boundary is None: - boundary = '---------------WerkzeugFormPart_%s%s' % (time(), random()) + boundary = "---------------WerkzeugFormPart_%s%s" % (time(), random()) _closure = [BytesIO(), 0, False] if use_tempfile: + def write_binary(string): stream, total_length, on_disk = _closure if on_disk: @@ -67,12 +80,13 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, if length + _closure[1] <= threshold: stream.write(string) else: - new_stream = TemporaryFile('wb+') + new_stream = TemporaryFile("wb+") new_stream.write(stream.getvalue()) new_stream.write(string) _closure[0] = new_stream _closure[2] = True _closure[1] = total_length + length + else: write_binary = _closure[0].write @@ -84,22 +98,22 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, for key, values in iterlists(values): for value in values: - write('--%s\r\nContent-Disposition: form-data; name="%s"' % - (boundary, key)) - reader = getattr(value, 'read', None) + write('--%s\r\nContent-Disposition: form-data; name="%s"' % (boundary, key)) + reader = getattr(value, "read", None) if reader is not None: - filename = getattr(value, 'filename', - getattr(value, 'name', None)) - content_type = getattr(value, 'content_type', None) + filename = getattr(value, "filename", getattr(value, "name", None)) + content_type = getattr(value, "content_type", None) if content_type is None: - content_type = filename and \ - mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' + content_type = ( + filename + and mimetypes.guess_type(filename)[0] + or "application/octet-stream" + ) if filename is not None: write('; filename="%s"\r\n' % filename) else: - write('\r\n') - write('Content-Type: %s\r\n\r\n' % content_type) + write("\r\n") + write("Content-Type: %s\r\n\r\n" % content_type) while 1: chunk = reader(16384) if not chunk: @@ -110,22 +124,23 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, value = str(value) value = to_bytes(value, charset) - write('\r\n\r\n') + write("\r\n\r\n") write_binary(value) - write('\r\n') - write('--%s--\r\n' % boundary) + write("\r\n") + write("--%s--\r\n" % boundary) length = int(_closure[0].tell()) _closure[0].seek(0) return _closure[0], length, boundary -def encode_multipart(values, boundary=None, charset='utf-8'): +def encode_multipart(values, boundary=None, charset="utf-8"): """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is a bytestring. """ stream, length, boundary = stream_encode_multipart( - values, use_tempfile=False, boundary=boundary, charset=charset) + values, use_tempfile=False, boundary=boundary, charset=charset + ) return boundary, stream.read() @@ -135,6 +150,7 @@ def File(fd, filename=None, mimetype=None): .. deprecated:: 0.5 """ from warnings import warn + warn( "'werkzeug.test.File' is deprecated as of version 0.5 and will" " be removed in version 1.0. Use 'EnvironBuilder' or" @@ -194,17 +210,16 @@ class _TestCookieJar(CookieJar): """ cvals = [] for cookie in self: - cvals.append('%s=%s' % (cookie.name, cookie.value)) + cvals.append("%s=%s" % (cookie.name, cookie.value)) if cvals: - environ['HTTP_COOKIE'] = '; '.join(cvals) + environ["HTTP_COOKIE"] = "; ".join(cvals) def extract_wsgi(self, environ, headers): """Extract the server's set-cookie headers as cookies into the cookie jar. """ self.extract_cookies( - _TestCookieResponse(headers), - U2Request(get_current_url(environ)), + _TestCookieResponse(headers), U2Request(get_current_url(environ)) ) @@ -307,7 +322,7 @@ class EnvironBuilder(object): """ #: the server protocol to use. defaults to HTTP/1.1 - server_protocol = 'HTTP/1.1' + server_protocol = "HTTP/1.1" #: the wsgi version to use. defaults to (1, 0) wsgi_version = (1, 0) @@ -316,21 +331,37 @@ class EnvironBuilder(object): request_class = BaseRequest import json + #: The serialization function used when ``json`` is passed. json_dumps = staticmethod(json.dumps) del json - def __init__(self, path='/', base_url=None, query_string=None, - method='GET', input_stream=None, content_type=None, - content_length=None, errors_stream=None, multithread=False, - multiprocess=False, run_once=False, headers=None, data=None, - environ_base=None, environ_overrides=None, charset='utf-8', - mimetype=None, json=None): + def __init__( + self, + path="/", + base_url=None, + query_string=None, + method="GET", + input_stream=None, + content_type=None, + content_length=None, + errors_stream=None, + multithread=False, + multiprocess=False, + run_once=False, + headers=None, + data=None, + environ_base=None, + environ_overrides=None, + charset="utf-8", + mimetype=None, + json=None, + ): path_s = make_literal_wrapper(path) - if query_string is not None and path_s('?') in path: - raise ValueError('Query string is defined in the path and as an argument') - if query_string is None and path_s('?') in path: - path, query_string = path.split(path_s('?'), 1) + if query_string is not None and path_s("?") in path: + raise ValueError("Query string is defined in the path and as an argument") + if query_string is None and path_s("?") in path: + path, query_string = path.split(path_s("?"), 1) self.charset = charset self.path = iri_to_uri(path) if base_url is not None: @@ -375,8 +406,8 @@ class EnvironBuilder(object): if data: if input_stream is not None: - raise TypeError('can\'t provide input stream and data') - if hasattr(data, 'read'): + raise TypeError("can't provide input stream and data") + if hasattr(data, "read"): data = data.read() if isinstance(data, text_type): data = data.encode(self.charset) @@ -386,8 +417,7 @@ class EnvironBuilder(object): self.content_length = len(data) else: for key, value in _iter_data(data): - if isinstance(value, (tuple, dict)) or \ - hasattr(value, 'read'): + if isinstance(value, (tuple, dict)) or hasattr(value, "read"): self._add_file_from_data(key, value) else: self.form.setlistdefault(key).append(value) @@ -406,9 +436,7 @@ class EnvironBuilder(object): out = { "path": environ["PATH_INFO"], "base_url": cls._make_base_url( - environ["wsgi.url_scheme"], - headers.pop("Host"), - environ["SCRIPT_NAME"], + environ["wsgi.url_scheme"], headers.pop("Host"), environ["SCRIPT_NAME"] ), "query_string": environ["QUERY_STRING"], "method": environ["REQUEST_METHOD"], @@ -430,6 +458,7 @@ class EnvironBuilder(object): self.files.add_file(key, *value) elif isinstance(value, dict): from warnings import warn + warn( "Passing a dict as file data is deprecated as of" " version 0.5 and will be removed in version 1.0. Use" @@ -438,9 +467,9 @@ class EnvironBuilder(object): stacklevel=2, ) value = dict(value) - mimetype = value.pop('mimetype', None) + mimetype = value.pop("mimetype", None) if mimetype is not None: - value['content_type'] = mimetype + value["content_type"] = mimetype self.files.add_file(key, **value) else: self.files.add_file(key, value) @@ -459,90 +488,100 @@ class EnvironBuilder(object): @base_url.setter def base_url(self, value): if value is None: - scheme = 'http' - netloc = 'localhost' - script_root = '' + scheme = "http" + netloc = "localhost" + script_root = "" else: scheme, netloc, script_root, qs, anchor = url_parse(value) if qs or anchor: - raise ValueError('base url must not contain a query string ' - 'or fragment') - self.script_root = script_root.rstrip('/') + raise ValueError("base url must not contain a query string or fragment") + self.script_root = script_root.rstrip("/") self.host = netloc self.url_scheme = scheme def _get_content_type(self): - ct = self.headers.get('Content-Type') + ct = self.headers.get("Content-Type") if ct is None and not self._input_stream: if self._files: - return 'multipart/form-data' + return "multipart/form-data" elif self._form: - return 'application/x-www-form-urlencoded' + return "application/x-www-form-urlencoded" return None return ct def _set_content_type(self, value): if value is None: - self.headers.pop('Content-Type', None) + self.headers.pop("Content-Type", None) else: - self.headers['Content-Type'] = value - - content_type = property(_get_content_type, _set_content_type, doc=''' - The content type for the request. Reflected from and to the - :attr:`headers`. Do not set if you set :attr:`files` or - :attr:`form` for auto detection.''') + self.headers["Content-Type"] = value + + content_type = property( + _get_content_type, + _set_content_type, + doc="""The content type for the request. Reflected from and to + the :attr:`headers`. Do not set if you set :attr:`files` or + :attr:`form` for auto detection.""", + ) del _get_content_type, _set_content_type def _get_content_length(self): - return self.headers.get('Content-Length', type=int) + return self.headers.get("Content-Length", type=int) def _get_mimetype(self): ct = self.content_type if ct: - return ct.split(';')[0].strip() + return ct.split(";")[0].strip() def _set_mimetype(self, value): self.content_type = get_content_type(value, self.charset) def _get_mimetype_params(self): def on_update(d): - self.headers['Content-Type'] = \ - dump_options_header(self.mimetype, d) - d = parse_options_header(self.headers.get('content-type', ''))[1] + self.headers["Content-Type"] = dump_options_header(self.mimetype, d) + + d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) - mimetype = property(_get_mimetype, _set_mimetype, doc=''' - The mimetype (content type without charset etc.) + mimetype = property( + _get_mimetype, + _set_mimetype, + doc="""The mimetype (content type without charset etc.) .. versionadded:: 0.14 - ''') - mimetype_params = property(_get_mimetype_params, doc=''' - The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be + """, + ) + mimetype_params = property( + _get_mimetype_params, + doc=""" The mimetype parameters as dict. For example if the + content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. .. versionadded:: 0.14 - ''') + """, + ) del _get_mimetype, _set_mimetype, _get_mimetype_params def _set_content_length(self, value): if value is None: - self.headers.pop('Content-Length', None) + self.headers.pop("Content-Length", None) else: - self.headers['Content-Length'] = str(value) + self.headers["Content-Length"] = str(value) - content_length = property(_get_content_length, _set_content_length, doc=''' - The content length as integer. Reflected from and to the + content_length = property( + _get_content_length, + _set_content_length, + doc="""The content length as integer. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or - :attr:`form` for auto detection.''') + :attr:`form` for auto detection.""", + ) del _get_content_length, _set_content_length - def form_property(name, storage, doc): - key = '_' + name + def form_property(name, storage, doc): # noqa: B902 + key = "_" + name def getter(self): if self._input_stream is not None: - raise AttributeError('an input stream is defined') + raise AttributeError("an input stream is defined") rv = getattr(self, key) if rv is None: rv = storage() @@ -553,14 +592,17 @@ class EnvironBuilder(object): def setter(self, value): self._input_stream = None setattr(self, key, value) + return property(getter, setter, doc=doc) - form = form_property('form', MultiDict, doc=''' - A :class:`MultiDict` of form values.''') - files = form_property('files', FileMultiDict, doc=''' - A :class:`FileMultiDict` of uploaded files. You can use the - :meth:`~FileMultiDict.add_file` method to add new files to the - dict.''') + form = form_property("form", MultiDict, doc="A :class:`MultiDict` of form values.") + files = form_property( + "files", + FileMultiDict, + doc="""A :class:`FileMultiDict` of uploaded files. You can use + the :meth:`~FileMultiDict.add_file` method to add new files to + the dict.""", + ) del form_property def _get_input_stream(self): @@ -570,30 +612,36 @@ class EnvironBuilder(object): self._input_stream = value self._form = self._files = None - input_stream = property(_get_input_stream, _set_input_stream, doc=''' - An optional input stream. If you set this it will clear - :attr:`form` and :attr:`files`.''') + input_stream = property( + _get_input_stream, + _set_input_stream, + doc="""An optional input stream. If you set this it will clear + :attr:`form` and :attr:`files`.""", + ) del _get_input_stream, _set_input_stream def _get_query_string(self): if self._query_string is None: if self._args is not None: return url_encode(self._args, charset=self.charset) - return '' + return "" return self._query_string def _set_query_string(self, value): self._query_string = value self._args = None - query_string = property(_get_query_string, _set_query_string, doc=''' - The query string. If you set this to a string :attr:`args` will - no longer be available.''') + query_string = property( + _get_query_string, + _set_query_string, + doc="""The query string. If you set this to a string + :attr:`args` will no longer be available.""", + ) del _get_query_string, _set_query_string def _get_args(self): if self._query_string is not None: - raise AttributeError('a query string is defined') + raise AttributeError("a query string is defined") if self._args is None: self._args = MultiDict() return self._args @@ -602,22 +650,23 @@ class EnvironBuilder(object): self._query_string = None self._args = value - args = property(_get_args, _set_args, doc=''' - The URL arguments as :class:`MultiDict`.''') + args = property( + _get_args, _set_args, doc="The URL arguments as :class:`MultiDict`." + ) del _get_args, _set_args @property def server_name(self): """The server name (read-only, use :attr:`host` to set)""" - return self.host.split(':', 1)[0] + return self.host.split(":", 1)[0] @property def server_port(self): """The server port as integer (read-only, use :attr:`host` to set)""" - pieces = self.host.split(':', 1) + pieces = self.host.split(":", 1) if len(pieces) == 2 and pieces[1].isdigit(): return int(pieces[1]) - elif self.url_scheme == 'https': + elif self.url_scheme == "https": return 443 return 80 @@ -665,15 +714,16 @@ class EnvironBuilder(object): end_pos = input_stream.tell() input_stream.seek(start_pos) content_length = end_pos - start_pos - elif mimetype == 'multipart/form-data': + elif mimetype == "multipart/form-data": values = CombinedMultiDict([self.form, self.files]) - input_stream, content_length, boundary = \ - stream_encode_multipart(values, charset=self.charset) + input_stream, content_length, boundary = stream_encode_multipart( + values, charset=self.charset + ) content_type = mimetype + '; boundary="%s"' % boundary - elif mimetype == 'application/x-www-form-urlencoded': + elif mimetype == "application/x-www-form-urlencoded": # XXX: py2v3 review values = url_encode(self.form, charset=self.charset) - values = values.encode('ascii') + values = values.encode("ascii") content_length = len(values) input_stream = BytesIO(values) else: @@ -688,40 +738,42 @@ class EnvironBuilder(object): qs = wsgi_encoding_dance(self.query_string) - result.update({ - 'REQUEST_METHOD': self.method, - 'SCRIPT_NAME': _path_encode(self.script_root), - 'PATH_INFO': _path_encode(self.path), - 'QUERY_STRING': qs, - # Non-standard, added by mod_wsgi, uWSGI - "REQUEST_URI": wsgi_encoding_dance(self.path), - # Non-standard, added by gunicorn - "RAW_URI": wsgi_encoding_dance(self.path), - 'SERVER_NAME': self.server_name, - 'SERVER_PORT': str(self.server_port), - 'HTTP_HOST': self.host, - 'SERVER_PROTOCOL': self.server_protocol, - 'wsgi.version': self.wsgi_version, - 'wsgi.url_scheme': self.url_scheme, - 'wsgi.input': input_stream, - 'wsgi.errors': self.errors_stream, - 'wsgi.multithread': self.multithread, - 'wsgi.multiprocess': self.multiprocess, - 'wsgi.run_once': self.run_once - }) + result.update( + { + "REQUEST_METHOD": self.method, + "SCRIPT_NAME": _path_encode(self.script_root), + "PATH_INFO": _path_encode(self.path), + "QUERY_STRING": qs, + # Non-standard, added by mod_wsgi, uWSGI + "REQUEST_URI": wsgi_encoding_dance(self.path), + # Non-standard, added by gunicorn + "RAW_URI": wsgi_encoding_dance(self.path), + "SERVER_NAME": self.server_name, + "SERVER_PORT": str(self.server_port), + "HTTP_HOST": self.host, + "SERVER_PROTOCOL": self.server_protocol, + "wsgi.version": self.wsgi_version, + "wsgi.url_scheme": self.url_scheme, + "wsgi.input": input_stream, + "wsgi.errors": self.errors_stream, + "wsgi.multithread": self.multithread, + "wsgi.multiprocess": self.multiprocess, + "wsgi.run_once": self.run_once, + } + ) headers = self.headers.copy() if content_type is not None: - result['CONTENT_TYPE'] = content_type + result["CONTENT_TYPE"] = content_type headers.set("Content-Type", content_type) if content_length is not None: - result['CONTENT_LENGTH'] = str(content_length) + result["CONTENT_LENGTH"] = str(content_length) headers.set("Content-Length", content_length) for key, value in headers.to_wsgi_list(): - result['HTTP_%s' % key.upper().replace('-', '_')] = value + result["HTTP_%s" % key.upper().replace("-", "_")] = value if self.environ_overrides: result.update(self.environ_overrides) @@ -740,15 +792,12 @@ class EnvironBuilder(object): class ClientRedirectError(Exception): - - """ - If a redirect loop is detected when using follow_redirects=True with + """If a redirect loop is detected when using follow_redirects=True with the :cls:`Client`, then this exception is raised. """ class Client(object): - """This class allows to send requests to a wrapped application. The response wrapper can be a class or factory function that takes @@ -781,8 +830,13 @@ class Client(object): The ``json`` parameter. """ - def __init__(self, application, response_wrapper=None, use_cookies=True, - allow_subdomain_redirects=False): + def __init__( + self, + application, + response_wrapper=None, + use_cookies=True, + allow_subdomain_redirects=False, + ): self.application = application self.response_wrapper = response_wrapper if use_cookies: @@ -791,24 +845,36 @@ class Client(object): self.cookie_jar = None self.allow_subdomain_redirects = allow_subdomain_redirects - def set_cookie(self, server_name, key, value='', max_age=None, - expires=None, path='/', domain=None, secure=None, - httponly=False, charset='utf-8'): + def set_cookie( + self, + server_name, + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=None, + httponly=False, + charset="utf-8", + ): """Sets a cookie in the client's cookie jar. The server name is required and has to match the one that is also passed to the open call. """ - assert self.cookie_jar is not None, 'cookies disabled' - header = dump_cookie(key, value, max_age, expires, path, domain, - secure, httponly, charset) - environ = create_environ(path, base_url='http://' + server_name) - headers = [('Set-Cookie', header)] + assert self.cookie_jar is not None, "cookies disabled" + header = dump_cookie( + key, value, max_age, expires, path, domain, secure, httponly, charset + ) + environ = create_environ(path, base_url="http://" + server_name) + headers = [("Set-Cookie", header)] self.cookie_jar.extract_wsgi(environ, headers) - def delete_cookie(self, server_name, key, path='/', domain=None): + def delete_cookie(self, server_name, key, path="/", domain=None): """Deletes a cookie in the test client.""" - self.set_cookie(server_name, key, expires=0, max_age=0, - path=path, domain=domain) + self.set_cookie( + server_name, key, expires=0, max_age=0, path=path, domain=domain + ) def run_wsgi_app(self, environ, buffered=False): """Runs the wrapped WSGI app with the given environment.""" @@ -826,7 +892,7 @@ class Client(object): scheme, netloc, path, qs, anchor = url_parse(new_location) builder = EnvironBuilder.from_environ(environ, query_string=qs) - to_name_parts = netloc.split(':', 1)[0].split(".") + to_name_parts = netloc.split(":", 1)[0].split(".") from_name_parts = builder.server_name.split(".") if to_name_parts != [""]: @@ -840,7 +906,7 @@ class Client(object): # Explain why a redirect to a different server name won't be followed. if to_name_parts != from_name_parts: - if to_name_parts[-len(from_name_parts):] == from_name_parts: + if to_name_parts[-len(from_name_parts) :] == from_name_parts: if not self.allow_subdomain_redirects: raise RuntimeError("Following subdomain redirects is not enabled.") else: @@ -849,9 +915,9 @@ class Client(object): path_parts = path.split("/") root_parts = builder.script_root.split("/") - if path_parts[:len(root_parts)] == root_parts: + if path_parts[: len(root_parts)] == root_parts: # Strip the script root from the path. - builder.path = path[len(builder.script_root):] + builder.path = path[len(builder.script_root) :] else: # The new location is not under the script root, so use the # whole path and clear the previous root. @@ -907,9 +973,9 @@ class Client(object): :param follow_redirects: Set this to True if the `Client` should follow HTTP redirects. """ - as_tuple = kwargs.pop('as_tuple', False) - buffered = kwargs.pop('buffered', False) - follow_redirects = kwargs.pop('follow_redirects', False) + as_tuple = kwargs.pop("as_tuple", False) + buffered = kwargs.pop("buffered", False) + follow_redirects = kwargs.pop("follow_redirects", False) environ = None if not kwargs and len(args) == 1: if isinstance(args[0], EnvironBuilder): @@ -941,15 +1007,13 @@ class Client(object): for _ in response[0]: pass - new_location = response[2]['location'] + new_location = response[2]["location"] new_redirect_entry = (new_location, status_code) if new_redirect_entry in redirect_chain: - raise ClientRedirectError('loop detected') + raise ClientRedirectError("loop detected") redirect_chain.append(new_redirect_entry) environ, response = self.resolve_redirect( - response, - new_location, - environ, buffered=buffered + response, new_location, environ, buffered=buffered ) if self.response_wrapper is not None: @@ -960,49 +1024,46 @@ class Client(object): def get(self, *args, **kw): """Like open but method is enforced to GET.""" - kw['method'] = 'GET' + kw["method"] = "GET" return self.open(*args, **kw) def patch(self, *args, **kw): """Like open but method is enforced to PATCH.""" - kw['method'] = 'PATCH' + kw["method"] = "PATCH" return self.open(*args, **kw) def post(self, *args, **kw): """Like open but method is enforced to POST.""" - kw['method'] = 'POST' + kw["method"] = "POST" return self.open(*args, **kw) def head(self, *args, **kw): """Like open but method is enforced to HEAD.""" - kw['method'] = 'HEAD' + kw["method"] = "HEAD" return self.open(*args, **kw) def put(self, *args, **kw): """Like open but method is enforced to PUT.""" - kw['method'] = 'PUT' + kw["method"] = "PUT" return self.open(*args, **kw) def delete(self, *args, **kw): """Like open but method is enforced to DELETE.""" - kw['method'] = 'DELETE' + kw["method"] = "DELETE" return self.open(*args, **kw) def options(self, *args, **kw): """Like open but method is enforced to OPTIONS.""" - kw['method'] = 'OPTIONS' + kw["method"] = "OPTIONS" return self.open(*args, **kw) def trace(self, *args, **kw): """Like open but method is enforced to TRACE.""" - kw['method'] = 'TRACE' + kw["method"] = "TRACE" return self.open(*args, **kw) def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.application - ) + return "<%s %r>" % (self.__class__.__name__, self.application) def create_environ(*args, **kwargs): @@ -1055,7 +1116,7 @@ def run_wsgi_app(app, environ, buffered=False): return buffer.append app_rv = app(environ, start_response) - close_func = getattr(app_rv, 'close', None) + close_func = getattr(app_rv, "close", None) app_iter = iter(app_rv) # when buffering we emit the close call early and convert the diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 972c6ca1..8ea23bee 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -9,15 +9,19 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import base64 import os import sys -import werkzeug from textwrap import wrap -from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response -from werkzeug.utils import escape -import base64 -logo = Response(base64.b64decode(''' +import werkzeug +from .utils import escape +from .wrappers import BaseRequest as Request +from .wrappers import BaseResponse as Response + +logo = Response( + base64.b64decode( + """ R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP///////// //////////////////////yH5BAEKAAgALAAAAACgAKAAAAT+EMlJq704680R+F0ojmRpnuj0rWnrv nB8rbRs33gu0bzu/0AObxgsGn3D5HHJbCUFyqZ0ukkSDlAidctNFg7gbI9LZlrBaHGtzAae0eloe25 @@ -51,10 +55,13 @@ UpyGlhjBUljyjHhWpf8OFaXwhp9O4T1gU9UeyPPa8A2l0p1kNqPXEVRm1AOs1oAGZU596t6SOR2mcB Oco1srWtkaVrMUzIErrKri85keKqRQYX9VX0/eAUK1hrSu6HMEX3Qh2sCh0q0D2CtnUqS4hj62sE/z aDs2Sg7MBS6xnQeooc2R2tC9YrKpEi9pLXfYXp20tDCpSP8rKlrD4axprb9u1Df5hSbz9QU0cRpfgn kiIzwKucd0wsEHlLpe5yHXuc6FrNelOl7pY2+11kTWx7VpRu97dXA3DO1vbkhcb4zyvERYajQgAADs -='''), mimetype='image/png') +=""" + ), + mimetype="image/png", +) -TEMPLATE = u'''\ +TEMPLATE = u"""\ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/TR/html4/loose.dtd"> <title>WSGI Information</title> @@ -130,24 +137,27 @@ TEMPLATE = u'''\ environments. <ul class="path">%(sys_path)s</ul> </div> -''' +""" def iter_sys_path(): - if os.name == 'posix': + if os.name == "posix": + def strip(x): - prefix = os.path.expanduser('~') + prefix = os.path.expanduser("~") if x.startswith(prefix): - x = '~' + x[len(prefix):] + x = "~" + x[len(prefix) :] return x + else: - strip = lambda x: x + + def strip(x): + return x cwd = os.path.abspath(os.getcwd()) for item in sys.path: path = os.path.join(cwd, item or os.path.curdir) - yield strip(os.path.normpath(path)), \ - not os.path.isdir(path), path != item + yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item def render_testapp(req): @@ -156,51 +166,51 @@ def render_testapp(req): except ImportError: eggs = () else: - eggs = sorted(pkg_resources.working_set, - key=lambda x: x.project_name.lower()) + eggs = sorted(pkg_resources.working_set, key=lambda x: x.project_name.lower()) python_eggs = [] for egg in eggs: try: version = egg.version except (ValueError, AttributeError): - version = 'unknown' - python_eggs.append('<li>%s <small>[%s]</small>' % ( - escape(egg.project_name), - escape(version) - )) + version = "unknown" + python_eggs.append( + "<li>%s <small>[%s]</small>" % (escape(egg.project_name), escape(version)) + ) wsgi_env = [] - sorted_environ = sorted(req.environ.items(), - key=lambda x: repr(x[0]).lower()) + sorted_environ = sorted(req.environ.items(), key=lambda x: repr(x[0]).lower()) for key, value in sorted_environ: - wsgi_env.append('<tr><th>%s<td><code>%s</code>' % ( - escape(str(key)), - ' '.join(wrap(escape(repr(value)))) - )) + wsgi_env.append( + "<tr><th>%s<td><code>%s</code>" + % (escape(str(key)), " ".join(wrap(escape(repr(value))))) + ) sys_path = [] for item, virtual, expanded in iter_sys_path(): class_ = [] if virtual: - class_.append('virtual') + class_.append("virtual") if expanded: - class_.append('exp') - sys_path.append('<li%s>%s' % ( - ' class="%s"' % ' '.join(class_) if class_ else '', - escape(item) - )) - - return (TEMPLATE % { - 'python_version': '<br>'.join(escape(sys.version).splitlines()), - 'platform': escape(sys.platform), - 'os': escape(os.name), - 'api_version': sys.api_version, - 'byteorder': sys.byteorder, - 'werkzeug_version': werkzeug.__version__, - 'python_eggs': '\n'.join(python_eggs), - 'wsgi_env': '\n'.join(wsgi_env), - 'sys_path': '\n'.join(sys_path) - }).encode('utf-8') + class_.append("exp") + sys_path.append( + "<li%s>%s" + % (' class="%s"' % " ".join(class_) if class_ else "", escape(item)) + ) + + return ( + TEMPLATE + % { + "python_version": "<br>".join(escape(sys.version).splitlines()), + "platform": escape(sys.platform), + "os": escape(os.name), + "api_version": sys.api_version, + "byteorder": sys.byteorder, + "werkzeug_version": werkzeug.__version__, + "python_eggs": "\n".join(python_eggs), + "wsgi_env": "\n".join(wsgi_env), + "sys_path": "\n".join(sys_path), + } + ).encode("utf-8") def test_app(environ, start_response): @@ -218,13 +228,14 @@ def test_app(environ, start_response): the Python interpreter and the installed libraries. """ req = Request(environ, populate_request=False) - if req.args.get('resource') == 'logo': + if req.args.get("resource") == "logo": response = logo else: - response = Response(render_testapp(req), mimetype='text/html') + response = Response(render_testapp(req), mimetype="text/html") return response(environ, start_response) -if __name__ == '__main__': - from werkzeug.serving import run_simple - run_simple('localhost', 5000, test_app, use_reloader=True) +if __name__ == "__main__": + from .serving import run_simple + + run_simple("localhost", 5000, test_app, use_reloader=True) diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index 09ec5b0c..38e9e5ad 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -15,56 +15,53 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from collections import namedtuple - import codecs import os import re +from collections import namedtuple -from werkzeug._compat import ( - PY2, - fix_tuple_repr, - implements_to_string, - make_literal_wrapper, - normalize_string_tuple, - text_type, - to_native, - to_unicode, - try_coerce_native, -) -from werkzeug._internal import _decode_idna, _encode_idna -from werkzeug.datastructures import MultiDict, iter_multi_items +from ._compat import fix_tuple_repr +from ._compat import implements_to_string +from ._compat import make_literal_wrapper +from ._compat import normalize_string_tuple +from ._compat import PY2 +from ._compat import text_type +from ._compat import to_native +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._internal import _decode_idna +from ._internal import _encode_idna +from .datastructures import iter_multi_items +from .datastructures import MultiDict # A regular expression for what a valid schema looks like -_scheme_re = re.compile(r'^[a-zA-Z0-9+-.]+$') +_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") # Characters that are safe in any part of an URL. -_always_safe = frozenset(bytearray( - b"abcdefghijklmnopqrstuvwxyz" - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"0123456789" - b"-._~" -)) - -_hexdigits = '0123456789ABCDEFabcdef' +_always_safe = frozenset( + bytearray( + b"abcdefghijklmnopqrstuvwxyz" + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" + b"0123456789" + b"-._~" + ) +) + +_hexdigits = "0123456789ABCDEFabcdef" _hextobyte = dict( - ((a + b).encode(), int(a + b, 16)) - for a in _hexdigits for b in _hexdigits + ((a + b).encode(), int(a + b, 16)) for a in _hexdigits for b in _hexdigits ) -_bytetohex = [ - ('%%%02X' % char).encode('ascii') for char in range(256) -] +_bytetohex = [("%%%02X" % char).encode("ascii") for char in range(256)] -_URLTuple = fix_tuple_repr(namedtuple( - '_URLTuple', - ['scheme', 'netloc', 'path', 'query', 'fragment'] -)) +_URLTuple = fix_tuple_repr( + namedtuple("_URLTuple", ["scheme", "netloc", "path", "query", "fragment"]) +) class BaseURL(_URLTuple): + """Superclass of :py:class:`URL` and :py:class:`BytesURL`.""" - '''Superclass of :py:class:`URL` and :py:class:`BytesURL`.''' __slots__ = () def replace(self, **kwargs): @@ -92,8 +89,8 @@ class BaseURL(_URLTuple): try: rv = _encode_idna(rv) except UnicodeError: - rv = rv.encode('ascii', 'ignore') - return to_native(rv, 'ascii', 'ignore') + rv = rv.encode("ascii", "ignore") + return to_native(rv, "ascii", "ignore") @property def port(self): @@ -169,19 +166,24 @@ class BaseURL(_URLTuple): def decode_netloc(self): """Decodes the netloc part into a string.""" - rv = _decode_idna(self.host or '') + rv = _decode_idna(self.host or "") - if ':' in rv: - rv = '[%s]' % rv + if ":" in rv: + rv = "[%s]" % rv port = self.port if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - _url_unquote_legacy(self.raw_username or '', '/:%@'), - _url_unquote_legacy(self.raw_password or '', '/:%@'), - ])) + rv = "%s:%d" % (rv, port) + auth = ":".join( + filter( + None, + [ + _url_unquote_legacy(self.raw_username or "", "/:%@"), + _url_unquote_legacy(self.raw_password or "", "/:%@"), + ], + ) + ) if auth: - rv = '%s@%s' % (auth, rv) + rv = "%s@%s" % (auth, rv) return rv def to_uri_tuple(self): @@ -192,7 +194,7 @@ class BaseURL(_URLTuple): It's usually more interesting to directly call :meth:`iri_to_uri` which will return a string. """ - return url_parse(iri_to_uri(self).encode('ascii')) + return url_parse(iri_to_uri(self).encode("ascii")) def to_iri_tuple(self): """Returns a :class:`URL` tuple that holds a IRI. This will try @@ -223,42 +225,44 @@ class BaseURL(_URLTuple): supported. Defaults to ``None`` which is autodetect. """ - if self.scheme != 'file': + if self.scheme != "file": return None, None path = url_unquote(self.path) host = self.netloc or None if pathformat is None: - if os.name == 'nt': - pathformat = 'windows' + if os.name == "nt": + pathformat = "windows" else: - pathformat = 'posix' + pathformat = "posix" - if pathformat == 'windows': - if path[:1] == '/' and path[1:2].isalpha() and path[2:3] in '|:': - path = path[1:2] + ':' + path[3:] - windows_share = path[:3] in ('\\' * 3, '/' * 3) + if pathformat == "windows": + if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": + path = path[1:2] + ":" + path[3:] + windows_share = path[:3] in ("\\" * 3, "/" * 3) import ntpath + path = ntpath.normpath(path) # Windows shared drives are represented as ``\\host\\directory``. # That results in a URL like ``file://///host/directory``, and a # path like ``///host/directory``. We need to special-case this # because the path contains the hostname. if windows_share and host is None: - parts = path.lstrip('\\').split('\\', 1) + parts = path.lstrip("\\").split("\\", 1) if len(parts) == 2: host, path = parts else: host = parts[0] - path = '' - elif pathformat == 'posix': + path = "" + elif pathformat == "posix": import posixpath + path = posixpath.normpath(path) else: - raise TypeError('Invalid path format %s' % repr(pathformat)) + raise TypeError("Invalid path format %s" % repr(pathformat)) - if host in ('127.0.0.1', '::1', 'localhost'): + if host in ("127.0.0.1", "::1", "localhost"): host = None return host, path @@ -291,7 +295,7 @@ class BaseURL(_URLTuple): return rv, None host = rv[1:idx] - rest = rv[idx + 1:] + rest = rv[idx + 1 :] if rest.startswith(self._colon): return host, rest[1:] return host, None @@ -299,81 +303,84 @@ class BaseURL(_URLTuple): @implements_to_string class URL(BaseURL): - """Represents a parsed URL. This behaves like a regular tuple but also has some extra attributes that give further insight into the URL. """ + __slots__ = () - _at = '@' - _colon = ':' - _lbracket = '[' - _rbracket = ']' + _at = "@" + _colon = ":" + _lbracket = "[" + _rbracket = "]" def __str__(self): return self.to_url() def encode_netloc(self): """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or '' - if ':' in rv: - rv = '[%s]' % rv + rv = self.ascii_host or "" + if ":" in rv: + rv = "[%s]" % rv port = self.port if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - url_quote(self.raw_username or '', 'utf-8', 'strict', '/:%'), - url_quote(self.raw_password or '', 'utf-8', 'strict', '/:%'), - ])) + rv = "%s:%d" % (rv, port) + auth = ":".join( + filter( + None, + [ + url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), + url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), + ], + ) + ) if auth: - rv = '%s@%s' % (auth, rv) + rv = "%s@%s" % (auth, rv) return to_native(rv) - def encode(self, charset='utf-8', errors='replace'): + def encode(self, charset="utf-8", errors="replace"): """Encodes the URL to a tuple made out of bytes. The charset is only being used for the path, query and fragment. """ return BytesURL( - self.scheme.encode('ascii'), + self.scheme.encode("ascii"), self.encode_netloc(), self.path.encode(charset, errors), self.query.encode(charset, errors), - self.fragment.encode(charset, errors) + self.fragment.encode(charset, errors), ) class BytesURL(BaseURL): - """Represents a parsed URL in bytes.""" + __slots__ = () - _at = b'@' - _colon = b':' - _lbracket = b'[' - _rbracket = b']' + _at = b"@" + _colon = b":" + _lbracket = b"[" + _rbracket = b"]" def __str__(self): - return self.to_url().decode('utf-8', 'replace') + return self.to_url().decode("utf-8", "replace") def encode_netloc(self): """Returns the netloc unchanged as bytes.""" return self.netloc - def decode(self, charset='utf-8', errors='replace'): + def decode(self, charset="utf-8", errors="replace"): """Decodes the URL to a tuple made out of strings. The charset is only being used for the path, query and fragment. """ return URL( - self.scheme.decode('ascii'), + self.scheme.decode("ascii"), self.decode_netloc(), self.path.decode(charset, errors), self.query.decode(charset, errors), - self.fragment.decode(charset, errors) + self.fragment.decode(charset, errors), ) -_unquote_maps = { - frozenset(): _hextobyte, -} +_unquote_maps = {frozenset(): _hextobyte} def _unquote_to_bytes(string, unsafe=""): @@ -418,15 +425,14 @@ def _url_encode_impl(obj, charset, encode_keys, sort, key): key = text_type(key).encode(charset) if not isinstance(value, bytes): value = text_type(value).encode(charset) - yield _fast_url_quote_plus(key) + '=' + _fast_url_quote_plus(value) + yield _fast_url_quote_plus(key) + "=" + _fast_url_quote_plus(value) -def _url_unquote_legacy(value, unsafe=''): +def _url_unquote_legacy(value, unsafe=""): try: - return url_unquote(value, charset='utf-8', - errors='strict', unsafe=unsafe) + return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) except UnicodeError: - return url_unquote(value, charset='latin1', unsafe=unsafe) + return url_unquote(value, charset="latin1", unsafe=unsafe) def url_parse(url, scheme=None, allow_fragments=True): @@ -446,38 +452,39 @@ def url_parse(url, scheme=None, allow_fragments=True): is_text_based = isinstance(url, text_type) if scheme is None: - scheme = s('') - netloc = query = fragment = s('') - i = url.find(s(':')) - if i > 0 and _scheme_re.match(to_native(url[:i], errors='replace')): + scheme = s("") + netloc = query = fragment = s("") + i = url.find(s(":")) + if i > 0 and _scheme_re.match(to_native(url[:i], errors="replace")): # make sure "iri" is not actually a port number (in which case # "scheme" is really part of the path) - rest = url[i + 1:] - if not rest or any(c not in s('0123456789') for c in rest): + rest = url[i + 1 :] + if not rest or any(c not in s("0123456789") for c in rest): # not a port number scheme, url = url[:i].lower(), rest - if url[:2] == s('//'): + if url[:2] == s("//"): delim = len(url) - for c in s('/?#'): + for c in s("/?#"): wdelim = url.find(c, 2) if wdelim >= 0: delim = min(delim, wdelim) netloc, url = url[2:delim], url[delim:] - if (s('[') in netloc and s(']') not in netloc) or \ - (s(']') in netloc and s('[') not in netloc): - raise ValueError('Invalid IPv6 URL') + if (s("[") in netloc and s("]") not in netloc) or ( + s("]") in netloc and s("[") not in netloc + ): + raise ValueError("Invalid IPv6 URL") - if allow_fragments and s('#') in url: - url, fragment = url.split(s('#'), 1) - if s('?') in url: - url, query = url.split(s('?'), 1) + if allow_fragments and s("#") in url: + url, fragment = url.split(s("#"), 1) + if s("?") in url: + url, query = url.split(s("?"), 1) result_type = URL if is_text_based else BytesURL return result_type(scheme, netloc, url, query, fragment) -def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe=''): +def _make_fast_url_quote(charset="utf-8", errors="strict", safe="/:", unsafe=""): """Precompile the translation table for a URL encoding function. Unlike :func:`url_quote`, the generated function only takes the @@ -495,12 +502,15 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='') unsafe = unsafe.encode(charset, errors) safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - table = [chr(c) if c in safe else '%%%02X' % c for c in range(256)] + table = [chr(c) if c in safe else "%%%02X" % c for c in range(256)] if not PY2: + def quote(string): return "".join([table[c] for c in string]) + else: + def quote(string): return "".join([table[c] for c in bytearray(string)]) @@ -508,14 +518,14 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='') _fast_url_quote = _make_fast_url_quote() -_fast_quote_plus = _make_fast_url_quote(safe=' ', unsafe='+') +_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+") def _fast_url_quote_plus(string): - return _fast_quote_plus(string).replace(' ', '+') + return _fast_quote_plus(string).replace(" ", "+") -def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''): +def url_quote(string, charset="utf-8", errors="strict", safe="/:", unsafe=""): """URL encode a single string with a given encoding. :param s: the string to quote. @@ -544,7 +554,7 @@ def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''): return to_native(bytes(rv)) -def url_quote_plus(string, charset='utf-8', errors='strict', safe=''): +def url_quote_plus(string, charset="utf-8", errors="strict", safe=""): """URL encode a single string with the given encoding and convert whitespace to "+". @@ -552,7 +562,7 @@ def url_quote_plus(string, charset='utf-8', errors='strict', safe=''): :param charset: The charset to be used. :param safe: An optional sequence of safe characters. """ - return url_quote(string, charset, errors, safe + ' ', '+').replace(' ', '+') + return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") def url_unparse(components): @@ -562,31 +572,30 @@ def url_unparse(components): :param components: the parsed URL as tuple which should be converted into a URL string. """ - scheme, netloc, path, query, fragment = \ - normalize_string_tuple(components) + scheme, netloc, path, query, fragment = normalize_string_tuple(components) s = make_literal_wrapper(scheme) - url = s('') + url = s("") # We generally treat file:///x and file:/x the same which is also # what browsers seem to do. This also allows us to ignore a schema # register for netloc utilization or having to differenciate between # empty and missing netloc. - if netloc or (scheme and path.startswith(s('/'))): - if path and path[:1] != s('/'): - path = s('/') + path - url = s('//') + (netloc or s('')) + path + if netloc or (scheme and path.startswith(s("/"))): + if path and path[:1] != s("/"): + path = s("/") + path + url = s("//") + (netloc or s("")) + path elif path: url += path if scheme: - url = scheme + s(':') + url + url = scheme + s(":") + url if query: - url = url + s('?') + query + url = url + s("?") + query if fragment: - url = url + s('#') + fragment + url = url + s("#") + fragment return url -def url_unquote(string, charset='utf-8', errors='replace', unsafe=''): +def url_unquote(string, charset="utf-8", errors="replace", unsafe=""): """URL decode a single string with a given encoding. If the charset is set to `None` no unicode decoding is performed and raw bytes are returned. @@ -602,7 +611,7 @@ def url_unquote(string, charset='utf-8', errors='replace', unsafe=''): return rv -def url_unquote_plus(s, charset='utf-8', errors='replace'): +def url_unquote_plus(s, charset="utf-8", errors="replace"): """URL decode a single string with the given `charset` and decode "+" to whitespace. @@ -616,13 +625,13 @@ def url_unquote_plus(s, charset='utf-8', errors='replace'): :param errors: The error handling for the `charset` decoding. """ if isinstance(s, text_type): - s = s.replace(u'+', u' ') + s = s.replace(u"+", u" ") else: - s = s.replace(b'+', b' ') + s = s.replace(b"+", b" ") return url_unquote(s, charset, errors) -def url_fix(s, charset='utf-8'): +def url_fix(s, charset="utf-8"): r"""Sometimes you get an URL by a user that just isn't a real URL because it contains unsafe characters like ' ' and so on. This function can fix some of the problems in a similar way browsers handle data entered by the @@ -638,19 +647,18 @@ def url_fix(s, charset='utf-8'): # First step is to switch to unicode processing and to convert # backslashes (which are invalid in URLs anyways) to slashes. This is # consistent with what Chrome does. - s = to_unicode(s, charset, 'replace').replace('\\', '/') + s = to_unicode(s, charset, "replace").replace("\\", "/") # For the specific case that we look like a malformed windows URL # we want to fix this up manually: - if s.startswith('file://') and s[7:8].isalpha() and s[8:10] in (':/', '|/'): - s = 'file:///' + s[7:] + if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"): + s = "file:///" + s[7:] url = url_parse(s) - path = url_quote(url.path, charset, safe='/%+$!*\'(),') - qs = url_quote_plus(url.query, charset, safe=':&%=+$!*\'(),') - anchor = url_quote_plus(url.fragment, charset, safe=':&%=+$!*\'(),') - return to_native(url_unparse((url.scheme, url.encode_netloc(), - path, qs, anchor))) + path = url_quote(url.path, charset, safe="/%+$!*'(),") + qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),") + anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),") + return to_native(url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor))) # not-unreserved characters remain quoted when unquoting to IRI @@ -661,7 +669,7 @@ def _codec_error_url_quote(e): """Used in :func:`uri_to_iri` after unquoting to re-quote any invalid bytes. """ - out = _fast_url_quote(e.object[e.start:e.end]) + out = _fast_url_quote(e.object[e.start : e.end]) if PY2: out = out.decode("utf-8") @@ -752,7 +760,7 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False): # contains ASCII characters, return it unconverted. try: native_iri = to_native(iri) - ascii_iri = native_iri.encode('ascii') + ascii_iri = native_iri.encode("ascii") # Only return if it doesn't have whitespace. (Why?) if len(ascii_iri.split()) == 1: @@ -769,8 +777,15 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False): ) -def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True, - errors='replace', separator='&', cls=None): +def url_decode( + s, + charset="utf-8", + decode_keys=False, + include_empty=True, + errors="replace", + separator="&", + cls=None, +): """ Parse a querystring and return it as :class:`MultiDict`. There is a difference in key decoding on different Python versions. On Python 3 @@ -812,16 +827,27 @@ def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True, if cls is None: cls = MultiDict if isinstance(s, text_type) and not isinstance(separator, text_type): - separator = separator.decode(charset or 'ascii') + separator = separator.decode(charset or "ascii") elif isinstance(s, bytes) and not isinstance(separator, bytes): - separator = separator.encode(charset or 'ascii') - return cls(_url_decode_impl(s.split(separator), charset, decode_keys, - include_empty, errors)) + separator = separator.encode(charset or "ascii") + return cls( + _url_decode_impl( + s.split(separator), charset, decode_keys, include_empty, errors + ) + ) -def url_decode_stream(stream, charset='utf-8', decode_keys=False, - include_empty=True, errors='replace', separator='&', - cls=None, limit=None, return_iterator=False): +def url_decode_stream( + stream, + charset="utf-8", + decode_keys=False, + include_empty=True, + errors="replace", + separator="&", + cls=None, + limit=None, + return_iterator=False, +): """Works like :func:`url_decode` but decodes a stream. The behavior of stream and limit follows functions like :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is @@ -849,14 +875,18 @@ def url_decode_stream(stream, charset='utf-8', decode_keys=False, and an iterator over all decoded pairs is returned """ - from werkzeug.wsgi import make_chunk_iter + from .wsgi import make_chunk_iter + + pair_iter = make_chunk_iter(stream, separator, limit) + decoder = _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors) + if return_iterator: - cls = lambda x: x - elif cls is None: + return decoder + + if cls is None: cls = MultiDict - pair_iter = make_chunk_iter(stream, separator, limit) - return cls(_url_decode_impl(pair_iter, charset, decode_keys, - include_empty, errors)) + + return cls(decoder) def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors): @@ -864,22 +894,23 @@ def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors): if not pair: continue s = make_literal_wrapper(pair) - equal = s('=') + equal = s("=") if equal in pair: key, value = pair.split(equal, 1) else: if not include_empty: continue key = pair - value = s('') + value = s("") key = url_unquote_plus(key, charset, errors) if charset is not None and PY2 and not decode_keys: key = try_coerce_native(key) yield key, url_unquote_plus(value, charset, errors) -def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None, - separator=b'&'): +def url_encode( + obj, charset="utf-8", encode_keys=False, sort=False, key=None, separator=b"&" +): """URL encode a dict/`MultiDict`. If a value is `None` it will not appear in the result string. Per default only values are encoded into the target charset strings. If `encode_keys` is set to ``True`` unicode keys are @@ -900,12 +931,19 @@ def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None, :param key: an optional function to be used for sorting. For more details check out the :func:`sorted` documentation. """ - separator = to_native(separator, 'ascii') + separator = to_native(separator, "ascii") return separator.join(_url_encode_impl(obj, charset, encode_keys, sort, key)) -def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False, - sort=False, key=None, separator=b'&'): +def url_encode_stream( + obj, + stream=None, + charset="utf-8", + encode_keys=False, + sort=False, + key=None, + separator=b"&", +): """Like :meth:`url_encode` but writes the results to a stream object. If the stream is `None` a generator over all encoded pairs is returned. @@ -924,7 +962,7 @@ def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False, :param key: an optional function to be used for sorting. For more details check out the :func:`sorted` documentation. """ - separator = to_native(separator, 'ascii') + separator = to_native(separator, "ascii") gen = _url_encode_impl(obj, charset, encode_keys, sort, key) if stream is None: return gen @@ -955,55 +993,53 @@ def url_join(base, url, allow_fragments=True): if not url: return base - bscheme, bnetloc, bpath, bquery, bfragment = \ - url_parse(base, allow_fragments=allow_fragments) - scheme, netloc, path, query, fragment = \ - url_parse(url, bscheme, allow_fragments) + bscheme, bnetloc, bpath, bquery, bfragment = url_parse( + base, allow_fragments=allow_fragments + ) + scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments) if scheme != bscheme: return url if netloc: return url_unparse((scheme, netloc, path, query, fragment)) netloc = bnetloc - if path[:1] == s('/'): - segments = path.split(s('/')) + if path[:1] == s("/"): + segments = path.split(s("/")) elif not path: - segments = bpath.split(s('/')) + segments = bpath.split(s("/")) if not query: query = bquery else: - segments = bpath.split(s('/'))[:-1] + path.split(s('/')) + segments = bpath.split(s("/"))[:-1] + path.split(s("/")) # If the rightmost part is "./" we want to keep the slash but # remove the dot. - if segments[-1] == s('.'): - segments[-1] = s('') + if segments[-1] == s("."): + segments[-1] = s("") # Resolve ".." and "." - segments = [segment for segment in segments if segment != s('.')] + segments = [segment for segment in segments if segment != s(".")] while 1: i = 1 n = len(segments) - 1 while i < n: - if segments[i] == s('..') and \ - segments[i - 1] not in (s(''), s('..')): - del segments[i - 1:i + 1] + if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): + del segments[i - 1 : i + 1] break i += 1 else: break # Remove trailing ".." if the URL is absolute - unwanted_marker = [s(''), s('..')] + unwanted_marker = [s(""), s("..")] while segments[:2] == unwanted_marker: del segments[1] - path = s('/').join(segments) + path = s("/").join(segments) return url_unparse((scheme, netloc, path, query, fragment)) class Href(object): - """Implements a callable that constructs URLs with the given base. The function can be called with any number of positional and keyword arguments which than are used to assemble the URL. Works with URLs @@ -1054,39 +1090,45 @@ class Href(object): `sort` and `key` were added. """ - def __init__(self, base='./', charset='utf-8', sort=False, key=None): + def __init__(self, base="./", charset="utf-8", sort=False, key=None): if not base: - base = './' + base = "./" self.base = base self.charset = charset self.sort = sort self.key = key def __getattr__(self, name): - if name[:2] == '__': + if name[:2] == "__": raise AttributeError(name) base = self.base - if base[-1:] != '/': - base += '/' + if base[-1:] != "/": + base += "/" return Href(url_join(base, name), self.charset, self.sort, self.key) def __call__(self, *path, **query): if path and isinstance(path[-1], dict): if query: - raise TypeError('keyword arguments and query-dicts ' - 'can\'t be combined') + raise TypeError("keyword arguments and query-dicts can't be combined") query, path = path[-1], path[:-1] elif query: - query = dict([(k.endswith('_') and k[:-1] or k, v) - for k, v in query.items()]) - path = '/'.join([to_unicode(url_quote(x, self.charset), 'ascii') - for x in path if x is not None]).lstrip('/') + query = dict( + [(k.endswith("_") and k[:-1] or k, v) for k, v in query.items()] + ) + path = "/".join( + [ + to_unicode(url_quote(x, self.charset), "ascii") + for x in path + if x is not None + ] + ).lstrip("/") rv = self.base if path: - if not rv.endswith('/'): - rv += '/' - rv = url_join(rv, './' + path) + if not rv.endswith("/"): + rv += "/" + rv = url_join(rv, "./" + path) if query: - rv += '?' + to_unicode(url_encode(query, self.charset, sort=self.sort, - key=self.key), 'ascii') + rv += "?" + to_unicode( + url_encode(query, self.charset, sort=self.sort, key=self.key), "ascii" + ) return to_native(rv) diff --git a/src/werkzeug/useragents.py b/src/werkzeug/useragents.py index 90d2649b..e265e093 100644 --- a/src/werkzeug/useragents.py +++ b/src/werkzeug/useragents.py @@ -12,80 +12,82 @@ :license: BSD-3-Clause """ import re +import warnings class UserAgentParser(object): - """A simple user agent parser. Used by the `UserAgent`.""" platforms = ( - ('cros', 'chromeos'), - ('iphone|ios', 'iphone'), - ('ipad', 'ipad'), - (r'darwin|mac|os\s*x', 'macos'), - ('win', 'windows'), - (r'android', 'android'), - ('netbsd', 'netbsd'), - ('openbsd', 'openbsd'), - ('freebsd', 'freebsd'), - ('dragonfly', 'dragonflybsd'), - ('(sun|i86)os', 'solaris'), - (r'x11|lin(\b|ux)?', 'linux'), - (r'nintendo\s+wii', 'wii'), - ('irix', 'irix'), - ('hp-?ux', 'hpux'), - ('aix', 'aix'), - ('sco|unix_sv', 'sco'), - ('bsd', 'bsd'), - ('amiga', 'amiga'), - ('blackberry|playbook', 'blackberry'), - ('symbian', 'symbian') + ("cros", "chromeos"), + ("iphone|ios", "iphone"), + ("ipad", "ipad"), + (r"darwin|mac|os\s*x", "macos"), + ("win", "windows"), + (r"android", "android"), + ("netbsd", "netbsd"), + ("openbsd", "openbsd"), + ("freebsd", "freebsd"), + ("dragonfly", "dragonflybsd"), + ("(sun|i86)os", "solaris"), + (r"x11|lin(\b|ux)?", "linux"), + (r"nintendo\s+wii", "wii"), + ("irix", "irix"), + ("hp-?ux", "hpux"), + ("aix", "aix"), + ("sco|unix_sv", "sco"), + ("bsd", "bsd"), + ("amiga", "amiga"), + ("blackberry|playbook", "blackberry"), + ("symbian", "symbian"), ) browsers = ( - ('googlebot', 'google'), - ('msnbot', 'msn'), - ('yahoo', 'yahoo'), - ('ask jeeves', 'ask'), - (r'aol|america\s+online\s+browser', 'aol'), - ('opera', 'opera'), - ('edge', 'edge'), - ('chrome|crios', 'chrome'), - ('seamonkey', 'seamonkey'), - ('firefox|firebird|phoenix|iceweasel', 'firefox'), - ('galeon', 'galeon'), - ('safari|version', 'safari'), - ('webkit', 'webkit'), - ('camino', 'camino'), - ('konqueror', 'konqueror'), - ('k-meleon', 'kmeleon'), - ('netscape', 'netscape'), - (r'msie|microsoft\s+internet\s+explorer|trident/.+? rv:', 'msie'), - ('lynx', 'lynx'), - ('links', 'links'), - ('Baiduspider', 'baidu'), - ('bingbot', 'bing'), - ('mozilla', 'mozilla') + ("googlebot", "google"), + ("msnbot", "msn"), + ("yahoo", "yahoo"), + ("ask jeeves", "ask"), + (r"aol|america\s+online\s+browser", "aol"), + ("opera", "opera"), + ("edge", "edge"), + ("chrome|crios", "chrome"), + ("seamonkey", "seamonkey"), + ("firefox|firebird|phoenix|iceweasel", "firefox"), + ("galeon", "galeon"), + ("safari|version", "safari"), + ("webkit", "webkit"), + ("camino", "camino"), + ("konqueror", "konqueror"), + ("k-meleon", "kmeleon"), + ("netscape", "netscape"), + (r"msie|microsoft\s+internet\s+explorer|trident/.+? rv:", "msie"), + ("lynx", "lynx"), + ("links", "links"), + ("Baiduspider", "baidu"), + ("bingbot", "bing"), + ("mozilla", "mozilla"), ) - _browser_version_re = r'(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?' + _browser_version_re = r"(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?" _language_re = re.compile( - r'(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|' - r'(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)' + r"(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|" + r"(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)" ) def __init__(self): self.platforms = [(b, re.compile(a, re.I)) for a, b in self.platforms] - self.browsers = [(b, re.compile(self._browser_version_re % a, re.I)) - for a, b in self.browsers] + self.browsers = [ + (b, re.compile(self._browser_version_re % a, re.I)) + for a, b in self.browsers + ] def __call__(self, user_agent): - for platform, regex in self.platforms: + for platform, regex in self.platforms: # noqa: B007 match = regex.search(user_agent) if match is not None: break else: platform = None - for browser, regex in self.browsers: + for browser, regex in self.browsers: # noqa: B007 match = regex.search(user_agent) if match is not None: version = match.group(1) @@ -101,7 +103,6 @@ class UserAgentParser(object): class UserAgent(object): - """Represents a user agent. Pass it a WSGI environment or a user agent string and you can inspect some of the details from the user agent string via the attributes. The following attributes exist: @@ -181,10 +182,11 @@ class UserAgent(object): def __init__(self, environ_or_string): if isinstance(environ_or_string, dict): - environ_or_string = environ_or_string.get('HTTP_USER_AGENT', '') + environ_or_string = environ_or_string.get("HTTP_USER_AGENT", "") self.string = environ_or_string - self.platform, self.browser, self.version, self.language = \ - self._parser(environ_or_string) + self.platform, self.browser, self.version, self.language = self._parser( + environ_or_string + ) def to_header(self): return self.string @@ -198,16 +200,21 @@ class UserAgent(object): __bool__ = __nonzero__ def __repr__(self): - return '<%s %r/%s>' % ( - self.__class__.__name__, - self.browser, - self.version - ) + return "<%s %r/%s>" % (self.__class__.__name__, self.browser, self.version) -# conceptionally this belongs in this module but because we want to lazily -# load the user agent module (which happens in wrappers.py) we have to import -# it afterwards. The class itself has the module set to this module so -# pickle, inspect and similar modules treat the object as if it was really -# implemented here. -from werkzeug.wrappers import UserAgentMixin # noqa +# DEPRECATED +from .wrappers import UserAgentMixin as _UserAgentMixin + + +class UserAgentMixin(_UserAgentMixin): + @property + def user_agent(self, *args, **kwargs): + warnings.warn( + "'werkzeug.useragents.UserAgentMixin' should be imported" + " from 'werkzeug.wrappers.UserAgentMixin'. This old import" + " will be removed in version 1.0.", + DeprecationWarning, + stacklevel=2, + ) + return super(_UserAgentMixin, self).user_agent diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index 406f62b3..20620572 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -11,32 +11,47 @@ :license: BSD-3-Clause """ import codecs -import re import os -import sys import pkgutil +import re +import sys import warnings +from ._compat import iteritems +from ._compat import PY2 +from ._compat import reraise +from ._compat import string_types +from ._compat import text_type +from ._compat import unichr +from ._internal import _DictAccessorProperty +from ._internal import _missing +from ._internal import _parse_signature + try: from html.entities import name2codepoint except ImportError: from htmlentitydefs import name2codepoint -from werkzeug._compat import unichr, text_type, string_types, iteritems, \ - reraise, PY2 -from werkzeug._internal import _DictAccessorProperty, \ - _parse_signature, _missing - -_format_re = re.compile(r'\$(?:(%s)|\{(%s)\})' % (('[a-zA-Z_][a-zA-Z0-9_]*',) * 2)) -_entity_re = re.compile(r'&([^;]+);') -_filename_ascii_strip_re = re.compile(r'[^A-Za-z0-9_.-]') -_windows_device_files = ('CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1', - 'LPT2', 'LPT3', 'PRN', 'NUL') +_format_re = re.compile(r"\$(?:(%s)|\{(%s)\})" % (("[a-zA-Z_][a-zA-Z0-9_]*",) * 2)) +_entity_re = re.compile(r"&([^;]+);") +_filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") +_windows_device_files = ( + "CON", + "AUX", + "COM1", + "COM2", + "COM3", + "COM4", + "LPT1", + "LPT2", + "LPT3", + "PRN", + "NUL", +) class cached_property(property): - """A decorator that converts a function into a lazy property. The function wrapped is called the first time to retrieve the result and then that calculated result is used the next time you access @@ -79,7 +94,6 @@ class cached_property(property): class environ_property(_DictAccessorProperty): - """Maps request attributes to environment variables. This works not only for the Werzeug request object, but also any other class with an environ attribute: @@ -107,7 +121,6 @@ class environ_property(_DictAccessorProperty): class header_property(_DictAccessorProperty): - """Like `environ_property` but for headers.""" def lookup(self, obj): @@ -115,7 +128,6 @@ class header_property(_DictAccessorProperty): class HTMLBuilder(object): - """Helper object for HTML generation. Per default there are two instances of that class. The `html` one, and @@ -141,20 +153,45 @@ class HTMLBuilder(object): u'<p><foo></p>' """ - _entity_re = re.compile(r'&([^;]+);') + _entity_re = re.compile(r"&([^;]+);") _entities = name2codepoint.copy() - _entities['apos'] = 39 - _empty_elements = set([ - 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame', - 'hr', 'img', 'input', 'keygen', 'isindex', 'link', 'meta', 'param', - 'source', 'wbr' - ]) - _boolean_attributes = set([ - 'selected', 'checked', 'compact', 'declare', 'defer', 'disabled', - 'ismap', 'multiple', 'nohref', 'noresize', 'noshade', 'nowrap' - ]) - _plaintext_elements = set(['textarea']) - _c_like_cdata = set(['script', 'style']) + _entities["apos"] = 39 + _empty_elements = { + "area", + "base", + "basefont", + "br", + "col", + "command", + "embed", + "frame", + "hr", + "img", + "input", + "keygen", + "isindex", + "link", + "meta", + "param", + "source", + "wbr", + } + _boolean_attributes = { + "selected", + "checked", + "compact", + "declare", + "defer", + "disabled", + "ismap", + "multiple", + "nohref", + "noresize", + "noshade", + "nowrap", + } + _plaintext_elements = {"textarea"} + _c_like_cdata = {"script", "style"} def __init__(self, dialect): self._dialect = dialect @@ -163,56 +200,56 @@ class HTMLBuilder(object): return escape(s) def __getattr__(self, tag): - if tag[:2] == '__': + if tag[:2] == "__": raise AttributeError(tag) def proxy(*children, **arguments): - buffer = '<' + tag + buffer = "<" + tag for key, value in iteritems(arguments): if value is None: continue - if key[-1] == '_': + if key[-1] == "_": key = key[:-1] if key in self._boolean_attributes: if not value: continue - if self._dialect == 'xhtml': + if self._dialect == "xhtml": value = '="' + key + '"' else: - value = '' + value = "" else: value = '="' + escape(value) + '"' - buffer += ' ' + key + value + buffer += " " + key + value if not children and tag in self._empty_elements: - if self._dialect == 'xhtml': - buffer += ' />' + if self._dialect == "xhtml": + buffer += " />" else: - buffer += '>' + buffer += ">" return buffer - buffer += '>' + buffer += ">" - children_as_string = ''.join([text_type(x) for x in children - if x is not None]) + children_as_string = "".join( + [text_type(x) for x in children if x is not None] + ) if children_as_string: if tag in self._plaintext_elements: children_as_string = escape(children_as_string) - elif tag in self._c_like_cdata and self._dialect == 'xhtml': - children_as_string = '/*<![CDATA[*/' + \ - children_as_string + '/*]]>*/' - buffer += children_as_string + '</' + tag + '>' + elif tag in self._c_like_cdata and self._dialect == "xhtml": + children_as_string = ( + "/*<![CDATA[*/" + children_as_string + "/*]]>*/" + ) + buffer += children_as_string + "</" + tag + ">" return buffer + return proxy def __repr__(self): - return '<%s for %r>' % ( - self.__class__.__name__, - self._dialect - ) + return "<%s for %r>" % (self.__class__.__name__, self._dialect) -html = HTMLBuilder('html') -xhtml = HTMLBuilder('xhtml') +html = HTMLBuilder("html") +xhtml = HTMLBuilder("xhtml") # https://cgit.freedesktop.org/xdg/shared-mime-info/tree/freedesktop.org.xml.in # https://www.iana.org/assignments/media-types/media-types.xhtml @@ -243,11 +280,11 @@ def get_content_type(mimetype, charset): ``application/javascript`` are also given charsets. """ if ( - mimetype.startswith('text/') + mimetype.startswith("text/") or mimetype in _charset_mimetypes - or mimetype.endswith('+xml') + or mimetype.endswith("+xml") ): - mimetype += '; charset=' + charset + mimetype += "; charset=" + charset return mimetype @@ -294,7 +331,7 @@ def detect_utf_encoding(data): return "utf-16-le" if len(head) == 2: - return "utf-16-be" if head.startswith(b'\x00') else "utf-16-le" + return "utf-16-be" if head.startswith(b"\x00") else "utf-16-le" return "utf-8" @@ -311,11 +348,13 @@ def format_string(string, context): :param string: the format string. :param context: a dict with the variables to insert. """ + def lookup_arg(match): x = context[match.group(1) or match.group(2)] if not isinstance(x, string_types): x = type(string)(x) return x + return _format_re.sub(lookup_arg, string) @@ -345,21 +384,26 @@ def secure_filename(filename): """ if isinstance(filename, text_type): from unicodedata import normalize - filename = normalize('NFKD', filename).encode('ascii', 'ignore') + + filename = normalize("NFKD", filename).encode("ascii", "ignore") if not PY2: - filename = filename.decode('ascii') + filename = filename.decode("ascii") for sep in os.path.sep, os.path.altsep: if sep: - filename = filename.replace(sep, ' ') - filename = str(_filename_ascii_strip_re.sub('', '_'.join( - filename.split()))).strip('._') + filename = filename.replace(sep, " ") + filename = str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip( + "._" + ) # on nt a couple of special files are present in each folder. We # have to ensure that the target file is not such a filename. In # this case we prepend an underline - if os.name == 'nt' and filename and \ - filename.split('.')[0].upper() in _windows_device_files: - filename = '_' + filename + if ( + os.name == "nt" + and filename + and filename.split(".")[0].upper() in _windows_device_files + ): + filename = "_" + filename return filename @@ -376,21 +420,26 @@ def escape(s, quote=None): :param quote: ignored. """ if s is None: - return '' - elif hasattr(s, '__html__'): + return "" + elif hasattr(s, "__html__"): return text_type(s.__html__()) elif not isinstance(s, string_types): s = text_type(s) if quote is not None: from warnings import warn + warn( "The 'quote' parameter is no longer used as of version 0.9" " and will be removed in version 1.0.", DeprecationWarning, stacklevel=2, ) - s = s.replace('&', '&').replace('<', '<') \ - .replace('>', '>').replace('"', """) + s = ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + ) return s @@ -400,18 +449,20 @@ def unescape(s): :param s: the string to unescape. """ + def handle_match(m): name = m.group(1) if name in HTMLBuilder._entities: return unichr(HTMLBuilder._entities[name]) try: - if name[:2] in ('#x', '#X'): + if name[:2] in ("#x", "#X"): return unichr(int(name[2:], 16)) - elif name.startswith('#'): + elif name.startswith("#"): return unichr(int(name[1:])) except ValueError: pass - return u'' + return u"" + return _entity_re.sub(handle_match, s) @@ -436,22 +487,26 @@ def redirect(location, code=302, Response=None): unspecified. """ if Response is None: - from werkzeug.wrappers import Response + from .wrappers import Response display_location = escape(location) if isinstance(location, text_type): # Safe conversion is necessary here as we might redirect # to a broken URI scheme (for instance itms-services). - from werkzeug.urls import iri_to_uri + from .urls import iri_to_uri + location = iri_to_uri(location, safe_conversion=True) response = Response( '<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' - '<title>Redirecting...</title>\n' - '<h1>Redirecting...</h1>\n' - '<p>You should be redirected automatically to target URL: ' - '<a href="%s">%s</a>. If not click the link.' % - (escape(location), display_location), code, mimetype='text/html') - response.headers['Location'] = location + "<title>Redirecting...</title>\n" + "<h1>Redirecting...</h1>\n" + "<p>You should be redirected automatically to target URL: " + '<a href="%s">%s</a>. If not click the link.' + % (escape(location), display_location), + code, + mimetype="text/html", + ) + response.headers["Location"] = location return response @@ -463,10 +518,10 @@ def append_slash_redirect(environ, code=301): the redirect. :param code: the status code for the redirect. """ - new_path = environ['PATH_INFO'].strip('/') + '/' - query_string = environ.get('QUERY_STRING') + new_path = environ["PATH_INFO"].strip("/") + "/" + query_string = environ.get("QUERY_STRING") if query_string: - new_path += '?' + query_string + new_path += "?" + query_string return redirect(new_path, code) @@ -486,17 +541,17 @@ def import_string(import_name, silent=False): # force the import name to automatically convert to strings # __import__ is not able to handle unicode strings in the fromlist # if the module is a package - import_name = str(import_name).replace(':', '.') + import_name = str(import_name).replace(":", ".") try: try: __import__(import_name) except ImportError: - if '.' not in import_name: + if "." not in import_name: raise else: return sys.modules[import_name] - module_name, obj_name = import_name.rsplit('.', 1) + module_name, obj_name = import_name.rsplit(".", 1) module = __import__(module_name, globals(), locals(), [obj_name]) try: return getattr(module, obj_name) @@ -506,9 +561,8 @@ def import_string(import_name, silent=False): except ImportError as e: if not silent: reraise( - ImportStringError, - ImportStringError(import_name, e), - sys.exc_info()[2]) + ImportStringError, ImportStringError(import_name, e), sys.exc_info()[2] + ) def find_modules(import_path, include_packages=False, recursive=False): @@ -527,11 +581,11 @@ def find_modules(import_path, include_packages=False, recursive=False): :return: generator """ module = import_string(import_path) - path = getattr(module, '__path__', None) + path = getattr(module, "__path__", None) if path is None: - raise ValueError('%r is not a package' % import_path) - basename = module.__name__ + '.' - for importer, modname, ispkg in pkgutil.iter_modules(path): + raise ValueError("%r is not a package" % import_path) + basename = module.__name__ + "." + for _importer, modname, ispkg in pkgutil.iter_modules(path): modname = basename + modname if ispkg: if include_packages: @@ -608,24 +662,32 @@ def bind_arguments(func, args, kwargs): :param kwargs: a dict of keyword arguments. :return: a :class:`dict` of bound keyword arguments. """ - args, kwargs, missing, extra, extra_positional, \ - arg_spec, vararg_var, kwarg_var = _parse_signature(func)(args, kwargs) + ( + args, + kwargs, + missing, + extra, + extra_positional, + arg_spec, + vararg_var, + kwarg_var, + ) = _parse_signature(func)(args, kwargs) values = {} - for (name, has_default, default), value in zip(arg_spec, args): + for (name, _has_default, _default), value in zip(arg_spec, args): values[name] = value if vararg_var is not None: values[vararg_var] = tuple(extra_positional) elif extra_positional: - raise TypeError('too many positional arguments') + raise TypeError("too many positional arguments") if kwarg_var is not None: multikw = set(extra) & set([x[0] for x in arg_spec]) if multikw: - raise TypeError('got multiple values for keyword argument ' - + repr(next(iter(multikw)))) + raise TypeError( + "got multiple values for keyword argument " + repr(next(iter(multikw))) + ) values[kwarg_var] = extra elif extra: - raise TypeError('got unexpected keyword argument ' - + repr(next(iter(extra)))) + raise TypeError("got unexpected keyword argument " + repr(next(iter(extra)))) return values @@ -637,15 +699,14 @@ class ArgumentValidationError(ValueError): self.missing = set(missing or ()) self.extra = extra or {} self.extra_positional = extra_positional or [] - ValueError.__init__(self, 'function arguments invalid. (' - '%d missing, %d additional)' % ( - len(self.missing), - len(self.extra) + len(self.extra_positional) - )) + ValueError.__init__( + self, + "function arguments invalid. (%d missing, %d additional)" + % (len(self.missing), len(self.extra) + len(self.extra_positional)), + ) class ImportStringError(ImportError): - """Provides information about a failed :func:`import_string` attempt.""" #: String in dotted notation that failed to be imported. @@ -658,47 +719,51 @@ class ImportStringError(ImportError): self.exception = exception msg = ( - 'import_string() failed for %r. Possible reasons are:\n\n' - '- missing __init__.py in a package;\n' - '- package or module path not included in sys.path;\n' - '- duplicated package or module name taking precedence in ' - 'sys.path;\n' - '- missing module, class, function or variable;\n\n' - 'Debugged import:\n\n%s\n\n' - 'Original exception:\n\n%s: %s') - - name = '' + "import_string() failed for %r. Possible reasons are:\n\n" + "- missing __init__.py in a package;\n" + "- package or module path not included in sys.path;\n" + "- duplicated package or module name taking precedence in " + "sys.path;\n" + "- missing module, class, function or variable;\n\n" + "Debugged import:\n\n%s\n\n" + "Original exception:\n\n%s: %s" + ) + + name = "" tracked = [] - for part in import_name.replace(':', '.').split('.'): - name += (name and '.') + part + for part in import_name.replace(":", ".").split("."): + name += (name and ".") + part imported = import_string(name, silent=True) if imported: - tracked.append((name, getattr(imported, '__file__', None))) + tracked.append((name, getattr(imported, "__file__", None))) else: - track = ['- %r found in %r.' % (n, i) for n, i in tracked] - track.append('- %r not found.' % name) - msg = msg % (import_name, '\n'.join(track), - exception.__class__.__name__, str(exception)) + track = ["- %r found in %r." % (n, i) for n, i in tracked] + track.append("- %r not found." % name) + msg = msg % ( + import_name, + "\n".join(track), + exception.__class__.__name__, + str(exception), + ) break ImportError.__init__(self, msg) def __repr__(self): - return '<%s(%r, %r)>' % (self.__class__.__name__, self.import_name, - self.exception) + return "<%s(%r, %r)>" % ( + self.__class__.__name__, + self.import_name, + self.exception, + ) # DEPRECATED -from werkzeug.datastructures import ( - MultiDict as _MultiDict, - CombinedMultiDict as _CombinedMultiDict, - Headers as _Headers, - EnvironHeaders as _EnvironHeaders, -) -from werkzeug.http import ( - parse_cookie as _parse_cookie, - dump_cookie as _dump_cookie, -) +from .datastructures import CombinedMultiDict as _CombinedMultiDict +from .datastructures import EnvironHeaders as _EnvironHeaders +from .datastructures import Headers as _Headers +from .datastructures import MultiDict as _MultiDict +from .http import dump_cookie as _dump_cookie +from .http import parse_cookie as _parse_cookie class MultiDict(_MultiDict): diff --git a/src/werkzeug/wrappers/accept.py b/src/werkzeug/wrappers/accept.py index 9fe50141..d0620a0a 100644 --- a/src/werkzeug/wrappers/accept.py +++ b/src/werkzeug/wrappers/accept.py @@ -17,15 +17,16 @@ class AcceptMixin(object): """List of mimetypes this client supports as :class:`~werkzeug.datastructures.MIMEAccept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT'), MIMEAccept) + return parse_accept_header(self.environ.get("HTTP_ACCEPT"), MIMEAccept) @cached_property def accept_charsets(self): """List of charsets this client supports as :class:`~werkzeug.datastructures.CharsetAccept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_CHARSET'), - CharsetAccept) + return parse_accept_header( + self.environ.get("HTTP_ACCEPT_CHARSET"), CharsetAccept + ) @cached_property def accept_encodings(self): @@ -33,7 +34,7 @@ class AcceptMixin(object): are compression encodings such as gzip. For charsets have a look at :attr:`accept_charset`. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_ENCODING')) + return parse_accept_header(self.environ.get("HTTP_ACCEPT_ENCODING")) @cached_property def accept_languages(self): @@ -44,5 +45,6 @@ class AcceptMixin(object): In previous versions this was a regular :class:`~werkzeug.datastructures.Accept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_LANGUAGE'), - LanguageAccept) + return parse_accept_header( + self.environ.get("HTTP_ACCEPT_LANGUAGE"), LanguageAccept + ) diff --git a/src/werkzeug/wrappers/auth.py b/src/werkzeug/wrappers/auth.py index 2660e9ba..714f7554 100644 --- a/src/werkzeug/wrappers/auth.py +++ b/src/werkzeug/wrappers/auth.py @@ -12,7 +12,7 @@ class AuthorizationMixin(object): @cached_property def authorization(self): """The `Authorization` object in parsed form.""" - header = self.environ.get('HTTP_AUTHORIZATION') + header = self.environ.get("HTTP_AUTHORIZATION") return parse_authorization_header(header) @@ -22,10 +22,12 @@ class WWWAuthenticateMixin(object): @property def www_authenticate(self): """The `WWW-Authenticate` header in a parsed form.""" + def on_update(www_auth): - if not www_auth and 'www-authenticate' in self.headers: - del self.headers['www-authenticate'] + if not www_auth and "www-authenticate" in self.headers: + del self.headers["www-authenticate"] elif www_auth: - self.headers['WWW-Authenticate'] = www_auth.to_header() - header = self.headers.get('www-authenticate') + self.headers["WWW-Authenticate"] = www_auth.to_header() + + header = self.headers.get("www-authenticate") return parse_www_authenticate_header(header, on_update) diff --git a/src/werkzeug/wrappers/base_request.py b/src/werkzeug/wrappers/base_request.py index c106c9c4..f5c40be9 100644 --- a/src/werkzeug/wrappers/base_request.py +++ b/src/werkzeug/wrappers/base_request.py @@ -72,10 +72,10 @@ class BaseRequest(object): """ #: the charset for the request, defaults to utf-8 - charset = 'utf-8' + charset = "utf-8" #: the error handling procedure for errors, defaults to 'replace' - encoding_errors = 'replace' + encoding_errors = "replace" #: the maximum content length. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -150,7 +150,7 @@ class BaseRequest(object): def __init__(self, environ, populate_request=True, shallow=False): self.environ = environ if populate_request and not shallow: - self.environ['werkzeug.request'] = self + self.environ["werkzeug.request"] = self self.shallow = shallow def __repr__(self): @@ -160,14 +160,11 @@ class BaseRequest(object): args = [] try: args.append("'%s'" % to_native(self.url, self.url_charset)) - args.append('[%s]' % self.method) + args.append("[%s]" % self.method) except Exception: - args.append('(invalid WSGI environ)') + args.append("(invalid WSGI environ)") - return '<%s %s>' % ( - self.__class__.__name__, - ' '.join(args) - ) + return "<%s %s>" % (self.__class__.__name__, " ".join(args)) @property def url_charset(self): @@ -197,9 +194,10 @@ class BaseRequest(object): :return: request object """ - from werkzeug.test import EnvironBuilder - charset = kwargs.pop('charset', cls.charset) - kwargs['charset'] = charset + from ..test import EnvironBuilder + + charset = kwargs.pop("charset", cls.charset) + kwargs["charset"] = charset builder = EnvironBuilder(*args, **kwargs) try: return builder.get_request(cls) @@ -228,7 +226,7 @@ class BaseRequest(object): #: the request. The return value is then called with the latest #: two arguments. This makes it possible to use this decorator for #: both methods and standalone WSGI functions. - from werkzeug.exceptions import HTTPException + from ..exceptions import HTTPException def application(*args): request = cls(args[-2]) @@ -241,8 +239,9 @@ class BaseRequest(object): return update_wrapper(application, f) - def _get_file_stream(self, total_content_length, content_type, filename=None, - content_length=None): + def _get_file_stream( + self, total_content_length, content_type, filename=None, content_length=None + ): """Called to get a stream for the file upload. This must provide a file-like class with `read()`, `readline()` @@ -266,7 +265,8 @@ class BaseRequest(object): total_content_length=total_content_length, content_type=content_type, filename=filename, - content_length=content_length) + content_length=content_length, + ) @property def want_form_data_parsed(self): @@ -275,7 +275,7 @@ class BaseRequest(object): .. versionadded:: 0.8 """ - return bool(self.environ.get('CONTENT_TYPE')) + return bool(self.environ.get("CONTENT_TYPE")) def make_form_data_parser(self): """Creates the form data parser. Instantiates the @@ -283,12 +283,14 @@ class BaseRequest(object): .. versionadded:: 0.8 """ - return self.form_data_parser_class(self._get_file_stream, - self.charset, - self.encoding_errors, - self.max_form_memory_size, - self.max_content_length, - self.parameter_storage_class) + return self.form_data_parser_class( + self._get_file_stream, + self.charset, + self.encoding_errors, + self.max_form_memory_size, + self.max_content_length, + self.parameter_storage_class, + ) def _load_form_data(self): """Method used internally to retrieve submitted data. After calling @@ -300,26 +302,30 @@ class BaseRequest(object): .. versionadded:: 0.8 """ # abort early if we have already consumed the stream - if 'form' in self.__dict__: + if "form" in self.__dict__: return _assert_not_shallow(self) if self.want_form_data_parsed: - content_type = self.environ.get('CONTENT_TYPE', '') + content_type = self.environ.get("CONTENT_TYPE", "") content_length = get_content_length(self.environ) mimetype, options = parse_options_header(content_type) parser = self.make_form_data_parser() - data = parser.parse(self._get_stream_for_parsing(), - mimetype, content_length, options) + data = parser.parse( + self._get_stream_for_parsing(), mimetype, content_length, options + ) else: - data = (self.stream, self.parameter_storage_class(), - self.parameter_storage_class()) + data = ( + self.stream, + self.parameter_storage_class(), + self.parameter_storage_class(), + ) # inject the values into the instance dict so that we bypass # our cached_property non-data descriptor. d = self.__dict__ - d['stream'], d['form'], d['files'] = data + d["stream"], d["form"], d["files"] = data def _get_stream_for_parsing(self): """This is the same as accessing :attr:`stream` with the difference @@ -328,7 +334,7 @@ class BaseRequest(object): .. versionadded:: 0.9.3 """ - cached_data = getattr(self, '_cached_data', None) + cached_data = getattr(self, "_cached_data", None) if cached_data is not None: return BytesIO(cached_data) return self.stream @@ -340,8 +346,8 @@ class BaseRequest(object): .. versionadded:: 0.9 """ - files = self.__dict__.get('files') - for key, value in iter_multi_items(files or ()): + files = self.__dict__.get("files") + for _key, value in iter_multi_items(files or ()): value.close() def __enter__(self): @@ -371,12 +377,14 @@ class BaseRequest(object): _assert_not_shallow(self) return get_input_stream(self.environ) - input_stream = environ_property('wsgi.input', """ - The WSGI input stream. + input_stream = environ_property( + "wsgi.input", + """The WSGI input stream. - In general it's a bad idea to use this one because you can easily read past - the boundary. Use the :attr:`stream` instead. - """) + In general it's a bad idea to use this one because you can + easily read past the boundary. Use the :attr:`stream` + instead.""", + ) @cached_property def args(self): @@ -389,9 +397,12 @@ class BaseRequest(object): :attr:`parameter_storage_class` to a different type. This might be necessary if the order of the form data is important. """ - return url_decode(wsgi_get_bytes(self.environ.get('QUERY_STRING', '')), - self.url_charset, errors=self.encoding_errors, - cls=self.parameter_storage_class) + return url_decode( + wsgi_get_bytes(self.environ.get("QUERY_STRING", "")), + self.url_charset, + errors=self.encoding_errors, + cls=self.parameter_storage_class, + ) @cached_property def data(self): @@ -401,7 +412,7 @@ class BaseRequest(object): """ if self.disable_data_descriptor: - raise AttributeError('data descriptor is disabled') + raise AttributeError("data descriptor is disabled") # XXX: this should eventually be deprecated. # We trigger form data parsing first which means that the descriptor @@ -436,7 +447,7 @@ class BaseRequest(object): .. versionadded:: 0.9 """ - rv = getattr(self, '_cached_data', None) + rv = getattr(self, "_cached_data", None) if rv is None: if parse_form_data: self._load_form_data() @@ -504,9 +515,12 @@ class BaseRequest(object): def cookies(self): """A :class:`dict` with the contents of all cookies transmitted with the request.""" - return parse_cookie(self.environ, self.charset, - self.encoding_errors, - cls=self.dict_storage_class) + return parse_cookie( + self.environ, + self.charset, + self.encoding_errors, + cls=self.dict_storage_class, + ) @cached_property def headers(self): @@ -521,37 +535,39 @@ class BaseRequest(object): info in the WSGI environment but will always include a leading slash, even if the URL root is accessed. """ - raw_path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '', - self.charset, self.encoding_errors) - return '/' + raw_path.lstrip('/') + raw_path = wsgi_decoding_dance( + self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors + ) + return "/" + raw_path.lstrip("/") @cached_property def full_path(self): """Requested path as unicode, including the query string.""" - return self.path + u'?' + to_unicode(self.query_string, self.url_charset) + return self.path + u"?" + to_unicode(self.query_string, self.url_charset) @cached_property def script_root(self): """The root path of the script without the trailing slash.""" - raw_path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '', - self.charset, self.encoding_errors) - return raw_path.rstrip('/') + raw_path = wsgi_decoding_dance( + self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors + ) + return raw_path.rstrip("/") @cached_property def url(self): """The reconstructed current URL as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, - trusted_hosts=self.trusted_hosts) + return get_current_url(self.environ, trusted_hosts=self.trusted_hosts) @cached_property def base_url(self): """Like :attr:`url` but without the querystring See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, strip_querystring=True, - trusted_hosts=self.trusted_hosts) + return get_current_url( + self.environ, strip_querystring=True, trusted_hosts=self.trusted_hosts + ) @cached_property def url_root(self): @@ -559,16 +575,16 @@ class BaseRequest(object): root as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, True, - trusted_hosts=self.trusted_hosts) + return get_current_url(self.environ, True, trusted_hosts=self.trusted_hosts) @cached_property def host_url(self): """Just the host with scheme as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, host_only=True, - trusted_hosts=self.trusted_hosts) + return get_current_url( + self.environ, host_only=True, trusted_hosts=self.trusted_hosts + ) @cached_property def host(self): @@ -578,39 +594,51 @@ class BaseRequest(object): return get_host(self.environ, trusted_hosts=self.trusted_hosts) query_string = environ_property( - 'QUERY_STRING', '', read_only=True, - load_func=wsgi_get_bytes, doc='The URL parameters as raw bytestring.') + "QUERY_STRING", + "", + read_only=True, + load_func=wsgi_get_bytes, + doc="The URL parameters as raw bytestring.", + ) method = environ_property( - 'REQUEST_METHOD', 'GET', read_only=True, + "REQUEST_METHOD", + "GET", + read_only=True, load_func=lambda x: x.upper(), - doc="The request method. (For example ``'GET'`` or ``'POST'``).") + doc="The request method. (For example ``'GET'`` or ``'POST'``).", + ) @cached_property def access_route(self): """If a forwarded header exists this is a list of all ip addresses from the client ip to the last proxy server. """ - if 'HTTP_X_FORWARDED_FOR' in self.environ: - addr = self.environ['HTTP_X_FORWARDED_FOR'].split(',') + if "HTTP_X_FORWARDED_FOR" in self.environ: + addr = self.environ["HTTP_X_FORWARDED_FOR"].split(",") return self.list_storage_class([x.strip() for x in addr]) - elif 'REMOTE_ADDR' in self.environ: - return self.list_storage_class([self.environ['REMOTE_ADDR']]) + elif "REMOTE_ADDR" in self.environ: + return self.list_storage_class([self.environ["REMOTE_ADDR"]]) return self.list_storage_class() @property def remote_addr(self): """The remote address of the client.""" - return self.environ.get('REMOTE_ADDR') - - remote_user = environ_property('REMOTE_USER', doc=''' - If the server supports user authentication, and the script is - protected, this attribute contains the username the user has - authenticated as.''') - - scheme = environ_property('wsgi.url_scheme', doc=''' + return self.environ.get("REMOTE_ADDR") + + remote_user = environ_property( + "REMOTE_USER", + doc="""If the server supports user authentication, and the + script is protected, this attribute contains the username the + user has authenticated as.""", + ) + + scheme = environ_property( + "wsgi.url_scheme", + doc=""" URL scheme (http or https). - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) @property def is_xhr(self): @@ -630,28 +658,36 @@ class BaseRequest(object): " is not standard and is unreliable. You may be able to use" " 'accept_mimetypes' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) - return self.environ.get( - 'HTTP_X_REQUESTED_WITH', '' - ).lower() == 'xmlhttprequest' - - is_secure = property(lambda self: self.environ['wsgi.url_scheme'] == 'https', - doc='`True` if the request is secure.') - is_multithread = environ_property('wsgi.multithread', doc=''' - boolean that is `True` if the application is served by - a multithreaded WSGI server.''') - is_multiprocess = environ_property('wsgi.multiprocess', doc=''' - boolean that is `True` if the application is served by - a WSGI server that spawns multiple processes.''') - is_run_once = environ_property('wsgi.run_once', doc=''' - boolean that is `True` if the application will be executed only - once in a process lifetime. This is the case for CGI for example, - but it's not guaranteed that the execution only happens one time.''') + return self.environ.get("HTTP_X_REQUESTED_WITH", "").lower() == "xmlhttprequest" + + is_secure = property( + lambda self: self.environ["wsgi.url_scheme"] == "https", + doc="`True` if the request is secure.", + ) + is_multithread = environ_property( + "wsgi.multithread", + doc="""boolean that is `True` if the application is served by a + multithreaded WSGI server.""", + ) + is_multiprocess = environ_property( + "wsgi.multiprocess", + doc="""boolean that is `True` if the application is served by a + WSGI server that spawns multiple processes.""", + ) + is_run_once = environ_property( + "wsgi.run_once", + doc="""boolean that is `True` if the application will be + executed only once in a process lifetime. This is the case for + CGI for example, but it's not guaranteed that the execution only + happens one time.""", + ) def _assert_not_shallow(request): if request.shallow: - raise RuntimeError('A shallow request tried to consume ' - 'form data. If you really want to do ' - 'that, set `shallow` to False.') + raise RuntimeError( + "A shallow request tried to consume form data. If you really" + " want to do that, set `shallow` to False." + ) diff --git a/src/werkzeug/wrappers/base_response.py b/src/werkzeug/wrappers/base_response.py index d7a8763b..d944a7d2 100644 --- a/src/werkzeug/wrappers/base_response.py +++ b/src/werkzeug/wrappers/base_response.py @@ -22,6 +22,7 @@ def _run_wsgi_app(*args): """ global _run_wsgi_app from ..test import run_wsgi_app as _run_wsgi_app + return _run_wsgi_app(*args) @@ -36,7 +37,7 @@ def _warn_if_string(iterable): " client one character at a time. This is almost never" " intended behavior, use 'response.data' to assign strings" " to the response object.", - stacklevel=2 + stacklevel=2, ) @@ -129,13 +130,13 @@ class BaseResponse(object): """ #: the charset of the response. - charset = 'utf-8' + charset = "utf-8" #: the default status if none is provided. default_status = 200 #: the default mimetype if none is provided. - default_mimetype = 'text/plain' + default_mimetype = "text/plain" #: if set to `False` accessing properties on the response object will #: not try to consume the response iterator and convert it into a list. @@ -169,8 +170,15 @@ class BaseResponse(object): #: .. _`cookie`: http://browsercookielimits.squawky.net/ max_cookie_size = 4093 - def __init__(self, response=None, status=None, headers=None, - mimetype=None, content_type=None, direct_passthrough=False): + def __init__( + self, + response=None, + status=None, + headers=None, + mimetype=None, + content_type=None, + direct_passthrough=False, + ): if isinstance(headers, Headers): self.headers = headers elif not headers: @@ -179,13 +187,13 @@ class BaseResponse(object): self.headers = Headers(headers) if content_type is None: - if mimetype is None and 'content-type' not in self.headers: + if mimetype is None and "content-type" not in self.headers: mimetype = self.default_mimetype if mimetype is not None: mimetype = get_content_type(mimetype, self.charset) content_type = mimetype if content_type is not None: - self.headers['Content-Type'] = content_type + self.headers["Content-Type"] = content_type if status is None: status = self.default_status if isinstance(status, integer_types): @@ -218,14 +226,10 @@ class BaseResponse(object): def __repr__(self): if self.is_sequence: - body_info = '%d bytes' % sum(map(len, self.iter_encoded())) + body_info = "%d bytes" % sum(map(len, self.iter_encoded())) else: - body_info = 'streamed' if self.is_streamed else 'likely-streamed' - return '<%s %s [%s]>' % ( - self.__class__.__name__, - body_info, - self.status - ) + body_info = "streamed" if self.is_streamed else "likely-streamed" + return "<%s %s [%s]>" % (self.__class__.__name__, body_info, self.status) @classmethod def force_type(cls, response, environ=None): @@ -258,8 +262,10 @@ class BaseResponse(object): """ if not isinstance(response, BaseResponse): if environ is None: - raise TypeError('cannot convert WSGI application into ' - 'response objects without an environ') + raise TypeError( + "cannot convert WSGI application into response" + " objects without an environ" + ) response = BaseResponse(*_run_wsgi_app(response, environ)) response.__class__ = cls return response @@ -286,11 +292,13 @@ class BaseResponse(object): def _set_status_code(self, code): self._status_code = code try: - self._status = '%d %s' % (code, HTTP_STATUS_CODES[code].upper()) + self._status = "%d %s" % (code, HTTP_STATUS_CODES[code].upper()) except KeyError: - self._status = '%d UNKNOWN' % code - status_code = property(_get_status_code, _set_status_code, - doc='The HTTP Status code as number') + self._status = "%d UNKNOWN" % code + + status_code = property( + _get_status_code, _set_status_code, doc="The HTTP Status code as number" + ) del _get_status_code, _set_status_code def _get_status(self): @@ -300,16 +308,17 @@ class BaseResponse(object): try: self._status = to_native(value) except AttributeError: - raise TypeError('Invalid status argument') + raise TypeError("Invalid status argument") try: self._status_code = int(self._status.split(None, 1)[0]) except ValueError: self._status_code = 0 - self._status = '0 %s' % self._status + self._status = "0 %s" % self._status except IndexError: - raise ValueError('Empty status argument') - status = property(_get_status, _set_status, doc='The HTTP Status code') + raise ValueError("Empty status argument") + + status = property(_get_status, _set_status, doc="The HTTP Status code") del _get_status, _set_status def get_data(self, as_text=False): @@ -326,7 +335,7 @@ class BaseResponse(object): .. versionadded:: 0.9 """ self._ensure_sequence() - rv = b''.join(self.iter_encoded()) + rv = b"".join(self.iter_encoded()) if as_text: rv = rv.decode(self.charset) return rv @@ -346,12 +355,12 @@ class BaseResponse(object): value = bytes(value) self.response = [value] if self.automatically_set_content_length: - self.headers['Content-Length'] = str(len(value)) + self.headers["Content-Length"] = str(len(value)) data = property( get_data, set_data, - doc='A descriptor that calls :meth:`get_data` and :meth:`set_data`.' + doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.", ) def calculate_content_length(self): @@ -375,14 +384,16 @@ class BaseResponse(object): self.response = list(self.response) return if self.direct_passthrough: - raise RuntimeError('Attempted implicit sequence conversion ' - 'but the response object is in direct ' - 'passthrough mode.') + raise RuntimeError( + "Attempted implicit sequence conversion but the" + " response object is in direct passthrough mode." + ) if not self.implicit_sequence_conversion: - raise RuntimeError('The response object required the iterable ' - 'to be a sequence, but the implicit ' - 'conversion was disabled. Call ' - 'make_sequence() yourself.') + raise RuntimeError( + "The response object required the iterable to be a" + " sequence, but the implicit conversion was disabled." + " Call make_sequence() yourself." + ) self.make_sequence() def make_sequence(self): @@ -397,7 +408,7 @@ class BaseResponse(object): # if we consume an iterable we have to ensure that the close # method of the iterable is called if available when we tear # down the response - close = getattr(self.response, 'close', None) + close = getattr(self.response, "close", None) self.response = list(self.iter_encoded()) if close is not None: self.call_on_close(close) @@ -415,9 +426,18 @@ class BaseResponse(object): # value from get_app_iter or iter_encoded. return _iter_encoded(self.response, self.charset) - def set_cookie(self, key, value='', max_age=None, expires=None, - path='/', domain=None, secure=False, httponly=False, - samesite=None): + def set_cookie( + self, + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=False, + httponly=False, + samesite=None, + ): """Sets a cookie. The parameters are the same as in the cookie `Morsel` object in the Python standard library but it accepts unicode data, too. @@ -445,21 +465,24 @@ class BaseResponse(object): be attached to requests if those requests are "same-site". """ - self.headers.add('Set-Cookie', dump_cookie( - key, - value=value, - max_age=max_age, - expires=expires, - path=path, - domain=domain, - secure=secure, - httponly=httponly, - charset=self.charset, - max_size=self.max_cookie_size, - samesite=samesite - )) - - def delete_cookie(self, key, path='/', domain=None): + self.headers.add( + "Set-Cookie", + dump_cookie( + key, + value=value, + max_age=max_age, + expires=expires, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + charset=self.charset, + max_size=self.max_cookie_size, + samesite=samesite, + ), + ) + + def delete_cookie(self, key, path="/", domain=None): """Delete a cookie. Fails silently if key doesn't exist. :param key: the key (name) of the cookie to be deleted. @@ -503,7 +526,7 @@ class BaseResponse(object): .. versionadded:: 0.9 Can now be used in a with statement. """ - if hasattr(self.response, 'close'): + if hasattr(self.response, "close"): self.response.close() for func in self._on_close: func() @@ -525,7 +548,7 @@ class BaseResponse(object): # we explicitly set the length to a list of the *encoded* response # iterator. Even if the implicit sequence conversion is disabled. self.response = list(self.iter_encoded()) - self.headers['Content-Length'] = str(sum(map(len, self.response))) + self.headers["Content-Length"] = str(sum(map(len, self.response))) def get_wsgi_headers(self, environ): """This is automatically called right before the response is started @@ -562,11 +585,11 @@ class BaseResponse(object): # speedup. for key, value in headers: ikey = key.lower() - if ikey == u'location': + if ikey == u"location": location = value - elif ikey == u'content-location': + elif ikey == u"content-location": content_location = value - elif ikey == u'content-length': + elif ikey == u"content-length": content_length = value # make sure the location header is an absolute URL @@ -583,17 +606,17 @@ class BaseResponse(object): current_url = iri_to_uri(current_url) location = url_join(current_url, location) if location != old_location: - headers['Location'] = location + headers["Location"] = location # make sure the content location is a URL - if content_location is not None and \ - isinstance(content_location, text_type): - headers['Content-Location'] = iri_to_uri(content_location) + if content_location is not None and isinstance(content_location, text_type): + headers["Content-Location"] = iri_to_uri(content_location) if 100 <= status < 200 or status == 204: - # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a Content-Length header field - # in any response with a status code of 1xx (Informational) or 204 (No Content)." - headers.remove('Content-Length') + # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a + # Content-Length header field in any response with a status + # code of 1xx (Informational) or 204 (No Content)." + headers.remove("Content-Length") elif status == 304: remove_entity_headers(headers) @@ -602,19 +625,21 @@ class BaseResponse(object): # flattening the iterator or encoding of unicode strings in # the response. We however should not do that if we have a 304 # response. - if self.automatically_set_content_length and \ - self.is_sequence and content_length is None and \ - status not in (204, 304) and \ - not (100 <= status < 200): + if ( + self.automatically_set_content_length + and self.is_sequence + and content_length is None + and status not in (204, 304) + and not (100 <= status < 200) + ): try: - content_length = sum(len(to_bytes(x, 'ascii')) - for x in self.response) + content_length = sum(len(to_bytes(x, "ascii")) for x in self.response) except UnicodeError: # aha, something non-bytestringy in there, too bad, we # can't safely figure out the length of the response. pass else: - headers['Content-Length'] = str(content_length) + headers["Content-Length"] = str(content_length) return headers @@ -633,8 +658,11 @@ class BaseResponse(object): :return: a response iterable. """ status = self.status_code - if environ['REQUEST_METHOD'] == 'HEAD' or \ - 100 <= status < 200 or status in (204, 304): + if ( + environ["REQUEST_METHOD"] == "HEAD" + or 100 <= status < 200 + or status in (204, 304) + ): iterable = () elif self.direct_passthrough: if __debug__: diff --git a/src/werkzeug/wrappers/common_descriptors.py b/src/werkzeug/wrappers/common_descriptors.py index 3ad0474a..e4107ee0 100644 --- a/src/werkzeug/wrappers/common_descriptors.py +++ b/src/werkzeug/wrappers/common_descriptors.py @@ -26,11 +26,13 @@ class CommonRequestDescriptorsMixin(object): .. versionadded:: 0.5 """ - content_type = environ_property('CONTENT_TYPE', doc=''' - The Content-Type entity-header field indicates the media type of - the entity-body sent to the recipient or, in the case of the HEAD - method, the media type that would have been sent had the request - been a GET.''') + content_type = environ_property( + "CONTENT_TYPE", + doc="""The Content-Type entity-header field indicates the media + type of the entity-body sent to the recipient or, in the case of + the HEAD method, the media type that would have been sent had + the request been a GET.""", + ) @cached_property def content_length(self): @@ -41,40 +43,58 @@ class CommonRequestDescriptorsMixin(object): """ return get_content_length(self.environ) - content_encoding = environ_property('HTTP_CONTENT_ENCODING', doc=''' - The Content-Encoding entity-header field is used as a modifier to the - media-type. When present, its value indicates what additional content - codings have been applied to the entity-body, and thus what decoding - mechanisms must be applied in order to obtain the media-type - referenced by the Content-Type header field. - - .. versionadded:: 0.9''') - content_md5 = environ_property('HTTP_CONTENT_MD5', doc=''' - The Content-MD5 entity-header field, as defined in RFC 1864, is an - MD5 digest of the entity-body for the purpose of providing an - end-to-end message integrity check (MIC) of the entity-body. (Note: - a MIC is good for detecting accidental modification of the - entity-body in transit, but is not proof against malicious attacks.) - - .. versionadded:: 0.9''') - referrer = environ_property('HTTP_REFERER', doc=''' - The Referer[sic] request-header field allows the client to specify, - for the server's benefit, the address (URI) of the resource from which - the Request-URI was obtained (the "referrer", although the header - field is misspelled).''') - date = environ_property('HTTP_DATE', None, parse_date, doc=''' - The Date general-header field represents the date and time at which - the message was originated, having the same semantics as orig-date - in RFC 822.''') - max_forwards = environ_property('HTTP_MAX_FORWARDS', None, int, doc=''' - The Max-Forwards request-header field provides a mechanism with the - TRACE and OPTIONS methods to limit the number of proxies or gateways - that can forward the request to the next inbound server.''') + content_encoding = environ_property( + "HTTP_CONTENT_ENCODING", + doc="""The Content-Encoding entity-header field is used as a + modifier to the media-type. When present, its value indicates + what additional content codings have been applied to the + entity-body, and thus what decoding mechanisms must be applied + in order to obtain the media-type referenced by the Content-Type + header field. + + .. versionadded:: 0.9""", + ) + content_md5 = environ_property( + "HTTP_CONTENT_MD5", + doc="""The Content-MD5 entity-header field, as defined in + RFC 1864, is an MD5 digest of the entity-body for the purpose of + providing an end-to-end message integrity check (MIC) of the + entity-body. (Note: a MIC is good for detecting accidental + modification of the entity-body in transit, but is not proof + against malicious attacks.) + + .. versionadded:: 0.9""", + ) + referrer = environ_property( + "HTTP_REFERER", + doc="""The Referer[sic] request-header field allows the client + to specify, for the server's benefit, the address (URI) of the + resource from which the Request-URI was obtained (the + "referrer", although the header field is misspelled).""", + ) + date = environ_property( + "HTTP_DATE", + None, + parse_date, + doc="""The Date general-header field represents the date and + time at which the message was originated, having the same + semantics as orig-date in RFC 822.""", + ) + max_forwards = environ_property( + "HTTP_MAX_FORWARDS", + None, + int, + doc="""The Max-Forwards request-header field provides a + mechanism with the TRACE and OPTIONS methods to limit the number + of proxies or gateways that can forward the request to the next + inbound server.""", + ) def _parse_content_type(self): - if not hasattr(self, '_parsed_content_type'): - self._parsed_content_type = \ - parse_options_header(self.environ.get('CONTENT_TYPE', '')) + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = parse_options_header( + self.environ.get("CONTENT_TYPE", "") + ) @property def mimetype(self): @@ -103,7 +123,7 @@ class CommonRequestDescriptorsMixin(object): optional behavior from the viewpoint of the protocol; however, some systems MAY require that behavior be consistent with the directives. """ - return parse_set_header(self.environ.get('HTTP_PRAGMA', '')) + return parse_set_header(self.environ.get("HTTP_PRAGMA", "")) class CommonResponseDescriptorsMixin(object): @@ -112,115 +132,157 @@ class CommonResponseDescriptorsMixin(object): HTTP headers with automatic type conversion. """ - def _get_mimetype(self): - ct = self.headers.get('content-type') + @property + def mimetype(self): + """The mimetype (content type without charset etc.)""" + ct = self.headers.get("content-type") if ct: - return ct.split(';')[0].strip() + return ct.split(";")[0].strip() + + @mimetype.setter + def mimetype(self, value): + self.headers["Content-Type"] = get_content_type(value, self.charset) + + @property + def mimetype_params(self): + """The mimetype parameters as dict. For example if the + content type is ``text/html; charset=utf-8`` the params would be + ``{'charset': 'utf-8'}``. - def _set_mimetype(self, value): - self.headers['Content-Type'] = get_content_type(value, self.charset) + .. versionadded:: 0.5 + """ - def _get_mimetype_params(self): def on_update(d): - self.headers['Content-Type'] = \ - dump_options_header(self.mimetype, d) - d = parse_options_header(self.headers.get('content-type', ''))[1] + self.headers["Content-Type"] = dump_options_header(self.mimetype, d) + + d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) - mimetype = property(_get_mimetype, _set_mimetype, doc=''' - The mimetype (content type without charset etc.)''') - mimetype_params = property(_get_mimetype_params, doc=''' - The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be - ``{'charset': 'utf-8'}``. + location = header_property( + "Location", + doc="""The Location response-header field is used to redirect + the recipient to a location other than the Request-URI for + completion of the request or identification of a new + resource.""", + ) + age = header_property( + "Age", + None, + parse_age, + dump_age, + doc="""The Age response-header field conveys the sender's + estimate of the amount of time since the response (or its + revalidation) was generated at the origin server. - .. versionadded:: 0.5 - ''') - location = header_property('Location', doc=''' - The Location response-header field is used to redirect the recipient - to a location other than the Request-URI for completion of the request - or identification of a new resource.''') - age = header_property('Age', None, parse_age, dump_age, doc=''' - The Age response-header field conveys the sender's estimate of the - amount of time since the response (or its revalidation) was - generated at the origin server. - - Age values are non-negative decimal integers, representing time in - seconds.''') - content_type = header_property('Content-Type', doc=''' - The Content-Type entity-header field indicates the media type of the - entity-body sent to the recipient or, in the case of the HEAD method, - the media type that would have been sent had the request been a GET. - ''') - content_length = header_property('Content-Length', None, int, str, doc=''' - The Content-Length entity-header field indicates the size of the - entity-body, in decimal number of OCTETs, sent to the recipient or, - in the case of the HEAD method, the size of the entity-body that would - have been sent had the request been a GET.''') - content_location = header_property('Content-Location', doc=''' - The Content-Location entity-header field MAY be used to supply the - resource location for the entity enclosed in the message when that - entity is accessible from a location separate from the requested - resource's URI.''') - content_encoding = header_property('Content-Encoding', doc=''' - The Content-Encoding entity-header field is used as a modifier to the - media-type. When present, its value indicates what additional content - codings have been applied to the entity-body, and thus what decoding - mechanisms must be applied in order to obtain the media-type - referenced by the Content-Type header field.''') - content_md5 = header_property('Content-MD5', doc=''' - The Content-MD5 entity-header field, as defined in RFC 1864, is an - MD5 digest of the entity-body for the purpose of providing an - end-to-end message integrity check (MIC) of the entity-body. (Note: - a MIC is good for detecting accidental modification of the - entity-body in transit, but is not proof against malicious attacks.) - ''') - date = header_property('Date', None, parse_date, http_date, doc=''' - The Date general-header field represents the date and time at which - the message was originated, having the same semantics as orig-date - in RFC 822.''') - expires = header_property('Expires', None, parse_date, http_date, doc=''' - The Expires entity-header field gives the date/time after which the - response is considered stale. A stale cache entry may not normally be - returned by a cache.''') - last_modified = header_property('Last-Modified', None, parse_date, - http_date, doc=''' - The Last-Modified entity-header field indicates the date and time at - which the origin server believes the variant was last modified.''') - - def _get_retry_after(self): - value = self.headers.get('retry-after') + Age values are non-negative decimal integers, representing time + in seconds.""", + ) + content_type = header_property( + "Content-Type", + doc="""The Content-Type entity-header field indicates the media + type of the entity-body sent to the recipient or, in the case of + the HEAD method, the media type that would have been sent had + the request been a GET.""", + ) + content_length = header_property( + "Content-Length", + None, + int, + str, + doc="""The Content-Length entity-header field indicates the size + of the entity-body, in decimal number of OCTETs, sent to the + recipient or, in the case of the HEAD method, the size of the + entity-body that would have been sent had the request been a + GET.""", + ) + content_location = header_property( + "Content-Location", + doc="""The Content-Location entity-header field MAY be used to + supply the resource location for the entity enclosed in the + message when that entity is accessible from a location separate + from the requested resource's URI.""", + ) + content_encoding = header_property( + "Content-Encoding", + doc="""The Content-Encoding entity-header field is used as a + modifier to the media-type. When present, its value indicates + what additional content codings have been applied to the + entity-body, and thus what decoding mechanisms must be applied + in order to obtain the media-type referenced by the Content-Type + header field.""", + ) + content_md5 = header_property( + "Content-MD5", + doc="""The Content-MD5 entity-header field, as defined in + RFC 1864, is an MD5 digest of the entity-body for the purpose of + providing an end-to-end message integrity check (MIC) of the + entity-body. (Note: a MIC is good for detecting accidental + modification of the entity-body in transit, but is not proof + against malicious attacks.)""", + ) + date = header_property( + "Date", + None, + parse_date, + http_date, + doc="""The Date general-header field represents the date and + time at which the message was originated, having the same + semantics as orig-date in RFC 822.""", + ) + expires = header_property( + "Expires", + None, + parse_date, + http_date, + doc="""The Expires entity-header field gives the date/time after + which the response is considered stale. A stale cache entry may + not normally be returned by a cache.""", + ) + last_modified = header_property( + "Last-Modified", + None, + parse_date, + http_date, + doc="""The Last-Modified entity-header field indicates the date + and time at which the origin server believes the variant was + last modified.""", + ) + + @property + def retry_after(self): + """The Retry-After response-header field can be used with a + 503 (Service Unavailable) response to indicate how long the + service is expected to be unavailable to the requesting client. + + Time in seconds until expiration or date. + """ + value = self.headers.get("retry-after") if value is None: return elif value.isdigit(): return datetime.utcnow() + timedelta(seconds=int(value)) return parse_date(value) - def _set_retry_after(self, value): + @retry_after.setter + def retry_after(self, value): if value is None: - if 'retry-after' in self.headers: - del self.headers['retry-after'] + if "retry-after" in self.headers: + del self.headers["retry-after"] return elif isinstance(value, datetime): value = http_date(value) else: value = str(value) - self.headers['Retry-After'] = value - - retry_after = property(_get_retry_after, _set_retry_after, doc=''' - The Retry-After response-header field can be used with a 503 (Service - Unavailable) response to indicate how long the service is expected - to be unavailable to the requesting client. - - Time in seconds until expiration or date.''') + self.headers["Retry-After"] = value - def _set_property(name, doc=None): + def _set_property(name, doc=None): # noqa: B902 def fget(self): def on_update(header_set): if not header_set and name in self.headers: del self.headers[name] elif header_set: self.headers[name] = header_set.to_header() + return parse_set_header(self.headers.get(name), on_update) def fset(self, value): @@ -230,24 +292,31 @@ class CommonResponseDescriptorsMixin(object): self.headers[name] = value else: self.headers[name] = dump_header(value) + return property(fget, fset, doc=doc) - vary = _set_property('Vary', doc=''' - The Vary field value indicates the set of request-header fields that - fully determines, while the response is fresh, whether a cache is - permitted to use the response to reply to a subsequent request - without revalidation.''') - content_language = _set_property('Content-Language', doc=''' - The Content-Language entity-header field describes the natural - language(s) of the intended audience for the enclosed entity. Note - that this might not be equivalent to all the languages used within - the entity-body.''') - allow = _set_property('Allow', doc=''' - The Allow entity-header field lists the set of methods supported - by the resource identified by the Request-URI. The purpose of this - field is strictly to inform the recipient of valid methods - associated with the resource. An Allow header field MUST be - present in a 405 (Method Not Allowed) response.''') - - del _set_property, _get_mimetype, _set_mimetype, _get_retry_after, \ - _set_retry_after + vary = _set_property( + "Vary", + doc="""The Vary field value indicates the set of request-header + fields that fully determines, while the response is fresh, + whether a cache is permitted to use the response to reply to a + subsequent request without revalidation.""", + ) + content_language = _set_property( + "Content-Language", + doc="""The Content-Language entity-header field describes the + natural language(s) of the intended audience for the enclosed + entity. Note that this might not be equivalent to all the + languages used within the entity-body.""", + ) + allow = _set_property( + "Allow", + doc="""The Allow entity-header field lists the set of methods + supported by the resource identified by the Request-URI. The + purpose of this field is strictly to inform the recipient of + valid methods associated with the resource. An Allow header + field MUST be present in a 405 (Method Not Allowed) + response.""", + ) + + del _set_property diff --git a/src/werkzeug/wrappers/etag.py b/src/werkzeug/wrappers/etag.py index 7ad12f61..0733506f 100644 --- a/src/werkzeug/wrappers/etag.py +++ b/src/werkzeug/wrappers/etag.py @@ -31,9 +31,8 @@ class ETagRequestMixin(object): """A :class:`~werkzeug.datastructures.RequestCacheControl` object for the incoming cache control headers. """ - cache_control = self.environ.get('HTTP_CACHE_CONTROL') - return parse_cache_control_header(cache_control, None, - RequestCacheControl) + cache_control = self.environ.get("HTTP_CACHE_CONTROL") + return parse_cache_control_header(cache_control, None, RequestCacheControl) @cached_property def if_match(self): @@ -41,7 +40,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.ETags` """ - return parse_etags(self.environ.get('HTTP_IF_MATCH')) + return parse_etags(self.environ.get("HTTP_IF_MATCH")) @cached_property def if_none_match(self): @@ -49,17 +48,17 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.ETags` """ - return parse_etags(self.environ.get('HTTP_IF_NONE_MATCH')) + return parse_etags(self.environ.get("HTTP_IF_NONE_MATCH")) @cached_property def if_modified_since(self): """The parsed `If-Modified-Since` header as datetime object.""" - return parse_date(self.environ.get('HTTP_IF_MODIFIED_SINCE')) + return parse_date(self.environ.get("HTTP_IF_MODIFIED_SINCE")) @cached_property def if_unmodified_since(self): """The parsed `If-Unmodified-Since` header as datetime object.""" - return parse_date(self.environ.get('HTTP_IF_UNMODIFIED_SINCE')) + return parse_date(self.environ.get("HTTP_IF_UNMODIFIED_SINCE")) @cached_property def if_range(self): @@ -69,7 +68,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.IfRange` """ - return parse_if_range_header(self.environ.get('HTTP_IF_RANGE')) + return parse_if_range_header(self.environ.get("HTTP_IF_RANGE")) @cached_property def range(self): @@ -79,7 +78,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.Range` """ - return parse_range_header(self.environ.get('HTTP_RANGE')) + return parse_range_header(self.environ.get("HTTP_RANGE")) class ETagResponseMixin(object): @@ -99,14 +98,16 @@ class ETagResponseMixin(object): directives that MUST be obeyed by all caching mechanisms along the request/response chain. """ + def on_update(cache_control): - if not cache_control and 'cache-control' in self.headers: - del self.headers['cache-control'] + if not cache_control and "cache-control" in self.headers: + del self.headers["cache-control"] elif cache_control: - self.headers['Cache-Control'] = cache_control.to_header() - return parse_cache_control_header(self.headers.get('cache-control'), - on_update, - ResponseCacheControl) + self.headers["Cache-Control"] = cache_control.to_header() + + return parse_cache_control_header( + self.headers.get("cache-control"), on_update, ResponseCacheControl + ) def _wrap_response(self, start, length): """Wrap existing Response in case of Range Request context.""" @@ -118,12 +119,15 @@ class ETagResponseMixin(object): resource is considered unchanged when compared with `If-Range` header. """ return ( - 'HTTP_IF_RANGE' not in environ + "HTTP_IF_RANGE" not in environ or not is_resource_modified( - environ, self.headers.get('etag'), None, - self.headers.get('last-modified'), ignore_if_range=False + environ, + self.headers.get("etag"), + None, + self.headers.get("last-modified"), + ignore_if_range=False, ) - ) and 'HTTP_RANGE' in environ + ) and "HTTP_RANGE" in environ def _process_range_request(self, environ, complete_length=None, accept_ranges=None): """Handle Range Request related headers (RFC7233). If `Accept-Ranges` @@ -136,13 +140,14 @@ class ETagResponseMixin(object): :raises: :class:`~werkzeug.exceptions.RequestedRangeNotSatisfiable` if `Range` header could not be parsed or satisfied. """ - from werkzeug.exceptions import RequestedRangeNotSatisfiable + from ..exceptions import RequestedRangeNotSatisfiable + if accept_ranges is None: return False - self.headers['Accept-Ranges'] = accept_ranges + self.headers["Accept-Ranges"] = accept_ranges if not self._is_range_request_processable(environ) or complete_length is None: return False - parsed_range = parse_range_header(environ.get('HTTP_RANGE')) + parsed_range = parse_range_header(environ.get("HTTP_RANGE")) if parsed_range is None: raise RequestedRangeNotSatisfiable(complete_length) range_tuple = parsed_range.range_for_length(complete_length) @@ -153,15 +158,16 @@ class ETagResponseMixin(object): # Be sure not to send 206 response # if requested range is the full content. if content_length != complete_length: - self.headers['Content-Length'] = content_length + self.headers["Content-Length"] = content_length self.content_range = content_range_header self.status_code = 206 self._wrap_response(range_tuple[0], content_length) return True return False - def make_conditional(self, request_or_environ, accept_ranges=False, - complete_length=None): + def make_conditional( + self, request_or_environ, accept_ranges=False, complete_length=None + ): """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -199,43 +205,48 @@ class ETagResponseMixin(object): if `Range` header could not be parsed or satisfied. """ environ = _get_environ(request_or_environ) - if environ['REQUEST_METHOD'] in ('GET', 'HEAD'): + if environ["REQUEST_METHOD"] in ("GET", "HEAD"): # if the date is not in the headers, add it now. We however # will not override an already existing header. Unfortunately # this header will be overriden by many WSGI servers including # wsgiref. - if 'date' not in self.headers: - self.headers['Date'] = http_date() + if "date" not in self.headers: + self.headers["Date"] = http_date() accept_ranges = _clean_accept_ranges(accept_ranges) is206 = self._process_range_request(environ, complete_length, accept_ranges) if not is206 and not is_resource_modified( - environ, self.headers.get('etag'), None, - self.headers.get('last-modified') + environ, + self.headers.get("etag"), + None, + self.headers.get("last-modified"), ): - if parse_etags(environ.get('HTTP_IF_MATCH')): + if parse_etags(environ.get("HTTP_IF_MATCH")): self.status_code = 412 else: self.status_code = 304 - if self.automatically_set_content_length and 'content-length' not in self.headers: + if ( + self.automatically_set_content_length + and "content-length" not in self.headers + ): length = self.calculate_content_length() if length is not None: - self.headers['Content-Length'] = length + self.headers["Content-Length"] = length return self def add_etag(self, overwrite=False, weak=False): """Add an etag for the current response if there is none yet.""" - if overwrite or 'etag' not in self.headers: + if overwrite or "etag" not in self.headers: self.set_etag(generate_etag(self.get_data()), weak) def set_etag(self, etag, weak=False): """Set the etag, and override the old one if there was one.""" - self.headers['ETag'] = quote_etag(etag, weak) + self.headers["ETag"] = quote_etag(etag, weak) def get_etag(self): """Return a tuple in the form ``(etag, is_weak)``. If there is no ETag the return value is ``(None, None)``. """ - return unquote_etag(self.headers.get('ETag')) + return unquote_etag(self.headers.get("ETag")) def freeze(self, no_etag=False): """Call this method if you want to make your response object ready for @@ -246,22 +257,25 @@ class ETagResponseMixin(object): self.add_etag() super(ETagResponseMixin, self).freeze() - accept_ranges = header_property('Accept-Ranges', doc=''' - The `Accept-Ranges` header. Even though the name would indicate - that multiple values are supported, it must be one string token only. + accept_ranges = header_property( + "Accept-Ranges", + doc="""The `Accept-Ranges` header. Even though the name would + indicate that multiple values are supported, it must be one + string token only. The values ``'bytes'`` and ``'none'`` are common. - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) def _get_content_range(self): def on_update(rng): if not rng: - del self.headers['content-range'] + del self.headers["content-range"] else: - self.headers['Content-Range'] = rng.to_header() - rv = parse_content_range_header(self.headers.get('content-range'), - on_update) + self.headers["Content-Range"] = rng.to_header() + + rv = parse_content_range_header(self.headers.get("content-range"), on_update) # always provide a content range object to make the descriptor # more user friendly. It provides an unset() method that can be # used to remove the header quickly. @@ -271,16 +285,20 @@ class ETagResponseMixin(object): def _set_content_range(self, value): if not value: - del self.headers['content-range'] + del self.headers["content-range"] elif isinstance(value, string_types): - self.headers['Content-Range'] = value + self.headers["Content-Range"] = value else: - self.headers['Content-Range'] = value.to_header() - content_range = property(_get_content_range, _set_content_range, doc=''' - The `Content-Range` header as - :class:`~werkzeug.datastructures.ContentRange` object. Even if the - header is not set it wil provide such an object for easier + self.headers["Content-Range"] = value.to_header() + + content_range = property( + _get_content_range, + _set_content_range, + doc="""The ``Content-Range`` header as + :class:`~werkzeug.datastructures.ContentRange` object. Even if + the header is not set it wil provide such an object for easier manipulation. - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) del _get_content_range, _set_content_range diff --git a/src/werkzeug/wrappers/json.py b/src/werkzeug/wrappers/json.py index 851806dd..6d5dc33d 100644 --- a/src/werkzeug/wrappers/json.py +++ b/src/werkzeug/wrappers/json.py @@ -75,9 +75,9 @@ class JSONMixin(object): """ mt = self.mimetype return ( - mt == 'application/json' - or mt.startswith('application/') - and mt.endswith('+json') + mt == "application/json" + or mt.startswith("application/") + and mt.endswith("+json") ) def _get_data_for_json(self, cache): @@ -142,4 +142,4 @@ class JSONMixin(object): for :meth:`get_json`. The default implementation raises :exc:`~werkzeug.exceptions.BadRequest`. """ - raise BadRequest('Failed to decode JSON object: {0}'.format(e)) + raise BadRequest("Failed to decode JSON object: {0}".format(e)) diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 5f1b00c8..d1c71b64 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -6,9 +6,14 @@ from .etag import ETagRequestMixin from .user_agent import UserAgentMixin -class Request(BaseRequest, AcceptMixin, ETagRequestMixin, - UserAgentMixin, AuthorizationMixin, - CommonRequestDescriptorsMixin): +class Request( + BaseRequest, + AcceptMixin, + ETagRequestMixin, + UserAgentMixin, + AuthorizationMixin, + CommonRequestDescriptorsMixin, +): """Full featured request object implementing the following mixins: - :class:`AcceptMixin` for accept header parsing diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index 8b7dc7ba..cd86cacd 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -11,7 +11,7 @@ class ResponseStream(object): iterable of the response object. """ - mode = 'wb+' + mode = "wb+" def __init__(self, response): self.response = response @@ -19,10 +19,10 @@ class ResponseStream(object): def write(self, value): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") self.response._ensure_sequence(mutable=True) self.response.response.append(value) - self.response.headers.pop('Content-Length', None) + self.response.headers.pop("Content-Length", None) return len(value) def writelines(self, seq): @@ -34,11 +34,11 @@ class ResponseStream(object): def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") def isatty(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return False def tell(self): @@ -62,9 +62,13 @@ class ResponseStreamMixin(object): return ResponseStream(self) -class Response(BaseResponse, ETagResponseMixin, ResponseStreamMixin, - CommonResponseDescriptorsMixin, - WWWAuthenticateMixin): +class Response( + BaseResponse, + ETagResponseMixin, + ResponseStreamMixin, + CommonResponseDescriptorsMixin, + WWWAuthenticateMixin, +): """Full featured response object implementing the following mixins: - :class:`ETagResponseMixin` for etag and cache control handling diff --git a/src/werkzeug/wrappers/user_agent.py b/src/werkzeug/wrappers/user_agent.py index be106efe..72588dd9 100644 --- a/src/werkzeug/wrappers/user_agent.py +++ b/src/werkzeug/wrappers/user_agent.py @@ -2,13 +2,14 @@ from ..utils import cached_property class UserAgentMixin(object): - """Adds a `user_agent` attribute to the request object which contains the - parsed user agent of the browser that triggered the request as a - :class:`~werkzeug.useragents.UserAgent` object. + """Adds a `user_agent` attribute to the request object which + contains the parsed user agent of the browser that triggered the + request as a :class:`~werkzeug.useragents.UserAgent` object. """ @cached_property def user_agent(self): """The current user agent.""" - from werkzeug.useragents import UserAgent + from ..useragents import UserAgent + return UserAgent(self.environ) diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 60966503..f069f2d8 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -11,14 +11,24 @@ import io import re import warnings -from functools import partial, update_wrapper +from functools import partial +from functools import update_wrapper from itertools import chain -from werkzeug._compat import BytesIO, implements_iterator, \ - make_literal_wrapper, string_types, text_type, to_bytes, to_unicode, \ - try_coerce_native, wsgi_get_bytes -from werkzeug._internal import _encode_idna -from werkzeug.urls import uri_to_iri, url_join, url_parse, url_quote +from ._compat import BytesIO +from ._compat import implements_iterator +from ._compat import make_literal_wrapper +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._compat import wsgi_get_bytes +from ._internal import _encode_idna +from .urls import uri_to_iri +from .urls import url_join +from .urls import url_parse +from .urls import url_quote def responder(f): @@ -34,8 +44,13 @@ def responder(f): return update_wrapper(lambda *a: f(*a)(*a[-2:]), f) -def get_current_url(environ, root_only=False, strip_querystring=False, - host_only=False, trusted_hosts=None): +def get_current_url( + environ, + root_only=False, + strip_querystring=False, + host_only=False, + trusted_hosts=None, +): """A handy helper function that recreates the full URL as IRI for the current request or parts of it. Here's an example: @@ -70,19 +85,19 @@ def get_current_url(environ, root_only=False, strip_querystring=False, :param trusted_hosts: a list of trusted hosts, see :func:`host_is_trusted` for more information. """ - tmp = [environ['wsgi.url_scheme'], '://', get_host(environ, trusted_hosts)] + tmp = [environ["wsgi.url_scheme"], "://", get_host(environ, trusted_hosts)] cat = tmp.append if host_only: - return uri_to_iri(''.join(tmp) + '/') - cat(url_quote(wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))).rstrip('/')) - cat('/') + return uri_to_iri("".join(tmp) + "/") + cat(url_quote(wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))).rstrip("/")) + cat("/") if not root_only: - cat(url_quote(wsgi_get_bytes(environ.get('PATH_INFO', '')).lstrip(b'/'))) + cat(url_quote(wsgi_get_bytes(environ.get("PATH_INFO", "")).lstrip(b"/"))) if not strip_querystring: qs = get_query_string(environ) if qs: - cat('?' + qs) - return uri_to_iri(''.join(tmp)) + cat("?" + qs) + return uri_to_iri("".join(tmp)) def host_is_trusted(hostname, trusted_list): @@ -103,8 +118,8 @@ def host_is_trusted(hostname, trusted_list): trusted_list = [trusted_list] def _normalize(hostname): - if ':' in hostname: - hostname = hostname.rsplit(':', 1)[0] + if ":" in hostname: + hostname = hostname.rsplit(":", 1)[0] return _encode_idna(hostname) try: @@ -112,7 +127,7 @@ def host_is_trusted(hostname, trusted_list): except UnicodeError: return False for ref in trusted_list: - if ref.startswith('.'): + if ref.startswith("."): ref = ref[1:] suffix_match = True else: @@ -123,7 +138,7 @@ def host_is_trusted(hostname, trusted_list): return False if ref == hostname: return True - if suffix_match and hostname.endswith(b'.' + ref): + if suffix_match and hostname.endswith(b"." + ref): return True return False @@ -144,20 +159,23 @@ def get_host(environ, trusted_hosts=None): :raise ~werkzeug.exceptions.SecurityError: If the host is not trusted. """ - if 'HTTP_HOST' in environ: - rv = environ['HTTP_HOST'] - if environ['wsgi.url_scheme'] == 'http' and rv.endswith(':80'): + if "HTTP_HOST" in environ: + rv = environ["HTTP_HOST"] + if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"): rv = rv[:-3] - elif environ['wsgi.url_scheme'] == 'https' and rv.endswith(':443'): + elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"): rv = rv[:-4] else: - rv = environ['SERVER_NAME'] - if (environ['wsgi.url_scheme'], environ['SERVER_PORT']) not \ - in (('https', '443'), ('http', '80')): - rv += ':' + environ['SERVER_PORT'] + rv = environ["SERVER_NAME"] + if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in ( + ("https", "443"), + ("http", "80"), + ): + rv += ":" + environ["SERVER_PORT"] if trusted_hosts is not None: if not host_is_trusted(rv, trusted_hosts): - from werkzeug.exceptions import SecurityError + from .exceptions import SecurityError + raise SecurityError('Host "%s" is not trusted' % rv) return rv @@ -171,10 +189,10 @@ def get_content_length(environ): :param environ: the WSGI environ to fetch the content length from. """ - if environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked': + if environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked": return None - content_length = environ.get('CONTENT_LENGTH') + content_length = environ.get("CONTENT_LENGTH") if content_length is not None: try: return max(0, int(content_length)) @@ -199,13 +217,13 @@ def get_input_stream(environ, safe_fallback=True): content length is not set. Disabling this allows infinite streams, which can be a denial-of-service risk. """ - stream = environ['wsgi.input'] + stream = environ["wsgi.input"] content_length = get_content_length(environ) # A wsgi extension that tells us if the input is terminated. In # that case we return the stream unchanged as we know we can safely # read it until the end. - if environ.get('wsgi.input_terminated'): + if environ.get("wsgi.input_terminated"): return stream # If the request doesn't specify a content length, returning the stream is @@ -228,14 +246,14 @@ def get_query_string(environ): :param environ: the WSGI environment object to get the query string from. """ - qs = wsgi_get_bytes(environ.get('QUERY_STRING', '')) + qs = wsgi_get_bytes(environ.get("QUERY_STRING", "")) # QUERY_STRING really should be ascii safe but some browsers # will send us some unicode stuff (I am looking at you IE). # In that case we want to urllib quote it badly. - return try_coerce_native(url_quote(qs, safe=':&%=+$!*\'(),')) + return try_coerce_native(url_quote(qs, safe=":&%=+$!*'(),")) -def get_path_info(environ, charset='utf-8', errors='replace'): +def get_path_info(environ, charset="utf-8", errors="replace"): """Returns the `PATH_INFO` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -248,11 +266,11 @@ def get_path_info(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('PATH_INFO', '')) + path = wsgi_get_bytes(environ.get("PATH_INFO", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def get_script_name(environ, charset='utf-8', errors='replace'): +def get_script_name(environ, charset="utf-8", errors="replace"): """Returns the `SCRIPT_NAME` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -265,11 +283,11 @@ def get_script_name(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('SCRIPT_NAME', '')) + path = wsgi_get_bytes(environ.get("SCRIPT_NAME", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def pop_path_info(environ, charset='utf-8', errors='replace'): +def pop_path_info(environ, charset="utf-8", errors="replace"): """Removes and returns the next segment of `PATH_INFO`, pushing it onto `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. @@ -296,32 +314,32 @@ def pop_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is modified. """ - path = environ.get('PATH_INFO') + path = environ.get("PATH_INFO") if not path: return None - script_name = environ.get('SCRIPT_NAME', '') + script_name = environ.get("SCRIPT_NAME", "") # shift multiple leading slashes over old_path = path - path = path.lstrip('/') + path = path.lstrip("/") if path != old_path: - script_name += '/' * (len(old_path) - len(path)) + script_name += "/" * (len(old_path) - len(path)) - if '/' not in path: - environ['PATH_INFO'] = '' - environ['SCRIPT_NAME'] = script_name + path + if "/" not in path: + environ["PATH_INFO"] = "" + environ["SCRIPT_NAME"] = script_name + path rv = wsgi_get_bytes(path) else: - segment, path = path.split('/', 1) - environ['PATH_INFO'] = '/' + path - environ['SCRIPT_NAME'] = script_name + segment + segment, path = path.split("/", 1) + environ["PATH_INFO"] = "/" + path + environ["SCRIPT_NAME"] = script_name + segment rv = wsgi_get_bytes(segment) return to_unicode(rv, charset, errors, allow_none_charset=True) -def peek_path_info(environ, charset='utf-8', errors='replace'): +def peek_path_info(environ, charset="utf-8", errors="replace"): """Returns the next segment on the `PATH_INFO` or `None` if there is none. Works like :func:`pop_path_info` without modifying the environment: @@ -342,18 +360,19 @@ def peek_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is checked. """ - segments = environ.get('PATH_INFO', '').lstrip('/').split('/', 1) + segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) if segments: - return to_unicode(wsgi_get_bytes(segments[0]), - charset, errors, allow_none_charset=True) + return to_unicode( + wsgi_get_bytes(segments[0]), charset, errors, allow_none_charset=True + ) def extract_path_info( environ_or_baseurl, path_or_url, - charset='utf-8', - errors='werkzeug.url_quote', - collapse_http_schemes=True + charset="utf-8", + errors="werkzeug.url_quote", + collapse_http_schemes=True, ): """Extracts the path info from the given URL (or WSGI environment) and path. The path info returned is a unicode string, not a bytestring @@ -395,29 +414,29 @@ def extract_path_info( .. versionadded:: 0.6 """ + def _normalize_netloc(scheme, netloc): - parts = netloc.split(u'@', 1)[-1].split(u':', 1) + parts = netloc.split(u"@", 1)[-1].split(u":", 1) if len(parts) == 2: netloc, port = parts - if (scheme == u'http' and port == u'80') or \ - (scheme == u'https' and port == u'443'): + if (scheme == u"http" and port == u"80") or ( + scheme == u"https" and port == u"443" + ): port = None else: netloc = parts[0] port = None if port is not None: - netloc += u':' + port + netloc += u":" + port return netloc # make sure whatever we are working on is a IRI and parse it path = uri_to_iri(path_or_url, charset, errors) if isinstance(environ_or_baseurl, dict): - environ_or_baseurl = get_current_url(environ_or_baseurl, - root_only=True) + environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) base_iri = uri_to_iri(environ_or_baseurl, charset, errors) base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path, = \ - url_parse(url_join(base_iri, path))[:3] + cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3] # normalize the network location base_netloc = _normalize_netloc(base_scheme, base_netloc) @@ -426,11 +445,10 @@ def extract_path_info( # is that IRI even on a known HTTP scheme? if collapse_http_schemes: for scheme in base_scheme, cur_scheme: - if scheme not in (u'http', u'https'): + if scheme not in (u"http", u"https"): return None else: - if not (base_scheme in (u'http', u'https') - and base_scheme == cur_scheme): + if not (base_scheme in (u"http", u"https") and base_scheme == cur_scheme): return None # are the netlocs compatible? @@ -438,16 +456,15 @@ def extract_path_info( return None # are we below the application path? - base_path = base_path.rstrip(u'/') + base_path = base_path.rstrip(u"/") if not cur_path.startswith(base_path): return None - return u'/' + cur_path[len(base_path):].lstrip(u'/') + return u"/" + cur_path[len(base_path) :].lstrip(u"/") @implements_iterator class ClosingIterator(object): - """The WSGI specification requires that all middlewares and gateways respect the `close` callback of the iterable returned by the application. Because it is useful to add another close action to a returned iterable @@ -478,7 +495,7 @@ class ClosingIterator(object): callbacks = [callbacks] else: callbacks = list(callbacks) - iterable_close = getattr(iterable, 'close', None) + iterable_close = getattr(iterable, "close", None) if iterable_close: callbacks.insert(0, iterable_close) self._callbacks = callbacks @@ -510,12 +527,11 @@ def wrap_file(environ, file, buffer_size=8192): :param file: a :class:`file`-like object with a :meth:`~file.read` method. :param buffer_size: number of bytes for one iteration. """ - return environ.get('wsgi.file_wrapper', FileWrapper)(file, buffer_size) + return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size) @implements_iterator class FileWrapper(object): - """This class can be used to convert a :class:`file`-like object into an iterable. It yields `buffer_size` blocks until the file is fully read. @@ -538,22 +554,22 @@ class FileWrapper(object): self.buffer_size = buffer_size def close(self): - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() def seekable(self): - if hasattr(self.file, 'seekable'): + if hasattr(self.file, "seekable"): return self.file.seekable() - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): return True return False def seek(self, *args): - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): self.file.seek(*args) def tell(self): - if hasattr(self.file, 'tell'): + if hasattr(self.file, "tell"): return self.file.tell() return None @@ -593,7 +609,7 @@ class _RangeWrapper(object): if byte_range is not None: self.end_byte = self.start_byte + self.byte_range self.read_length = 0 - self.seekable = hasattr(iterable, 'seekable') and iterable.seekable() + self.seekable = hasattr(iterable, "seekable") and iterable.seekable() self.end_reached = False def __iter__(self): @@ -618,7 +634,7 @@ class _RangeWrapper(object): while self.read_length <= self.start_byte: chunk = self._next_chunk() if chunk is not None: - chunk = chunk[self.start_byte - self.read_length:] + chunk = chunk[self.start_byte - self.read_length :] contextual_read_length = self.start_byte return chunk, contextual_read_length @@ -633,7 +649,7 @@ class _RangeWrapper(object): chunk = self._next_chunk() if self.end_byte is not None and self.read_length >= self.end_byte: self.end_reached = True - return chunk[:self.end_byte - contextual_read_length] + return chunk[: self.end_byte - contextual_read_length] return chunk def __next__(self): @@ -644,16 +660,17 @@ class _RangeWrapper(object): raise StopIteration() def close(self): - if hasattr(self.iterable, 'close'): + if hasattr(self.iterable, "close"): self.iterable.close() def _make_chunk_iter(stream, limit, buffer_size): """Helper for the line and chunk iter functions.""" if isinstance(stream, (bytes, bytearray, text_type)): - raise TypeError('Passed a string or byte object instead of ' - 'true iterator or stream.') - if not hasattr(stream, 'read'): + raise TypeError( + "Passed a string or byte object instead of true iterator or stream." + ) + if not hasattr(stream, "read"): for item in stream: if item: yield item @@ -668,8 +685,7 @@ def _make_chunk_iter(stream, limit, buffer_size): yield item -def make_line_iter(stream, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_line_iter(stream, limit=None, buffer_size=10 * 1024, cap_at_buffer=False): """Safely iterates line-based over an input stream. If the input stream is not a :class:`LimitedStream` the `limit` parameter is mandatory. @@ -703,15 +719,15 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return s = make_literal_wrapper(first_item) - empty = s('') - cr = s('\r') - lf = s('\n') - crlf = s('\r\n') + empty = s("") + cr = s("\r") + lf = s("\n") + crlf = s("\r\n") _iter = chain((first_item,), _iter) @@ -719,7 +735,7 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, _join = empty.join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break new_buf = [] @@ -754,8 +770,9 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, yield previous -def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_chunk_iter( + stream, separator, limit=None, buffer_size=10 * 1024, cap_at_buffer=False +): """Works like :func:`make_line_iter` but accepts a separator which divides chunks. If you want newline based processing you should use :func:`make_line_iter` instead as it @@ -782,23 +799,23 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return _iter = chain((first_item,), _iter) if isinstance(first_item, text_type): separator = to_unicode(separator) - _split = re.compile(r'(%s)' % re.escape(separator)).split - _join = u''.join + _split = re.compile(r"(%s)" % re.escape(separator)).split + _join = u"".join else: separator = to_bytes(separator) - _split = re.compile(b'(' + re.escape(separator) + b')').split - _join = b''.join + _split = re.compile(b"(" + re.escape(separator) + b")").split + _join = b"".join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break chunks = _split(new_data) @@ -828,7 +845,6 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, @implements_iterator class LimitedStream(io.IOBase): - """Wraps a stream so that it doesn't read more than n bytes. If the stream is exhausted and the caller tries to get more bytes from it :func:`on_exhausted` is called which by default returns an empty @@ -891,7 +907,8 @@ class LimitedStream(io.IOBase): the client went away. By default a :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised. """ - from werkzeug.exceptions import ClientDisconnected + from .exceptions import ClientDisconnected + raise ClientDisconnected() def exhaust(self, chunk_size=1024 * 64): @@ -984,9 +1001,10 @@ class LimitedStream(io.IOBase): return True -from werkzeug.middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware -from werkzeug.middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware -from werkzeug.middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware +# DEPRECATED +from .middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware +from .middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware +from .middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware class ProxyMiddleware(_ProxyMiddleware): |
