diff options
| author | David Lord <davidism@gmail.com> | 2019-02-13 11:44:18 -0800 |
|---|---|---|
| committer | David Lord <davidism@gmail.com> | 2019-03-08 08:01:31 -0800 |
| commit | ab6150fa49afc61b0c5eed6d9545d03d1958e384 (patch) | |
| tree | ad5f13c9c2775ca59cc8e82ec124c4e065a65d1b | |
| parent | 048d707d25685e6aea675c53945ceb7619e60344 (diff) | |
| download | werkzeug-code-style.tar.gz | |
apply code stylecode-style
* reorder-python-imports
* line fixers
* black
* flake8
151 files changed, 11013 insertions, 8798 deletions
diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..e32c8029 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,13 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +insert_final_newline = true +trim_trailing_whitespace = true +end_of_line = lf +charset = utf-8 +max_line_length = 88 + +[*.{yml,yaml,json,js,css,html}] +indent_size = 2 diff --git a/.gitattributes b/.gitattributes index 96d79dad..5946e823 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1 @@ -tests/res/chunked.txt binary - +tests/**/*.http binary diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..2fb46619 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: https://github.com/asottile/reorder_python_imports + rev: v1.4.0 + hooks: + - id: reorder-python-imports + name: Reorder Python imports (src, tests) + files: "^(?!examples/)" + args: ["--application-directories", ".:src"] + - id: reorder-python-imports + name: Reorder Python imports (examples) + files: "^examples/" + args: ["--application-directories", "examples"] + - repo: https://github.com/ambv/black + rev: 18.9b0 + hooks: + - id: black + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.7 + hooks: + - id: flake8 + additional_dependencies: [flake8-bugbear] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.1.0 + hooks: + - id: check-byte-order-marker + - id: trailing-whitespace + - id: end-of-file-fixer + exclude: "^tests/.*.http$" diff --git a/bench/wzbench.py b/bench/wzbench.py index f75742f7..6270b620 100755 --- a/bench/wzbench.py +++ b/bench/wzbench.py @@ -11,17 +11,20 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from __future__ import division, print_function -import os +from __future__ import division +from __future__ import print_function + import gc -import sys +import os import subprocess +import sys +from timeit import default_timer as timer +from types import FunctionType + try: from cStringIO import StringIO except ImportError: from io import StringIO -from timeit import default_timer as timer -from types import FunctionType PY2 = sys.version_info[0] == 2 @@ -30,10 +33,9 @@ if not PY2: # create a new module where we later store all the werkzeug attributes. -wz = type(sys)('werkzeug_nonlazy') -sys.path.insert(0, '<DUMMY>') -null_out = open(os.devnull, 'w') - +wz = type(sys)("werkzeug_nonlazy") +sys.path.insert(0, "<DUMMY>") +null_out = open(os.devnull, "w") # ±4% are ignored TOLERANCE = 0.04 @@ -47,8 +49,9 @@ def find_hg_tag(path): """Returns the current node or tag for the given path.""" tags = {} try: - client = subprocess.Popen(['hg', 'cat', '-r', 'tip', '.hgtags'], - stdout=subprocess.PIPE, cwd=path) + client = subprocess.Popen( + ["hg", "cat", "-r", "tip", ".hgtags"], stdout=subprocess.PIPE, cwd=path + ) for line in client.communicate()[0].splitlines(): line = line.strip() if not line: @@ -58,8 +61,9 @@ def find_hg_tag(path): except OSError: return - client = subprocess.Popen(['hg', 'parent', '--template', '#node#'], - stdout=subprocess.PIPE, cwd=path) + client = subprocess.Popen( + ["hg", "parent", "--template", "#node#"], stdout=subprocess.PIPE, cwd=path + ) tip = client.communicate()[0].strip() tag = tags.get(tip) @@ -75,11 +79,12 @@ def load_werkzeug(path): # get rid of already imported stuff wz.__dict__.clear() for key in sys.modules.keys(): - if key.startswith('werkzeug.') or key == 'werkzeug': + if key.startswith("werkzeug.") or key == "werkzeug": sys.modules.pop(key, None) # import werkzeug again. import werkzeug + for key in werkzeug.__all__: setattr(wz, key, getattr(werkzeug, key)) @@ -88,18 +93,18 @@ def load_werkzeug(path): # get the real version from the setup file try: - f = open(os.path.join(path, 'setup.py')) + f = open(os.path.join(path, "setup.py")) except IOError: pass else: try: for line in f: line = line.strip() - if line.startswith('version='): - return line[8:].strip(' \t,')[1:-1], hg_tag + if line.startswith("version="): + return line[8:].strip(" \t,")[1:-1], hg_tag finally: f.close() - print('Unknown werkzeug version loaded', file=sys.stderr) + print("Unknown werkzeug version loaded", file=sys.stderr) sys.exit(2) @@ -115,14 +120,14 @@ def format_func(func): name = func.__name__ else: name = func - if name.startswith('time_'): + if name.startswith("time_"): name = name[5:] - return name.replace('_', ' ').title() + return name.replace("_", " ").title() def bench(func): """Times a single function.""" - sys.stdout.write('%44s ' % format_func(func)) + sys.stdout.write("%44s " % format_func(func)) sys.stdout.flush() # figure out how many times we have to run the function to @@ -130,7 +135,7 @@ def bench(func): for i in xrange(3, 10): rounds = 1 << i t = timer() - for x in xrange(rounds): + for _ in xrange(rounds): func() if timer() - t >= 0.2: break @@ -142,14 +147,14 @@ def bench(func): gc.disable() try: t = timer() - for x in xrange(rounds): + for _ in xrange(rounds): func() return (timer() - t) / rounds * 1000 finally: gc.enable() delta = median(_run() for x in xrange(TEST_RUNS)) - sys.stdout.write('%.4f\n' % delta) + sys.stdout.write("%.4f\n" % delta) sys.stdout.flush() return delta @@ -158,17 +163,33 @@ def bench(func): def main(): """The main entrypoint.""" from optparse import OptionParser - parser = OptionParser(usage='%prog [options]') - parser.add_option('--werkzeug-path', '-p', dest='path', default='..', - help='the path to the werkzeug package. defaults to cwd') - parser.add_option('--compare', '-c', dest='compare', nargs=2, - default=False, help='compare two hg nodes of Werkzeug') - parser.add_option('--init-compare', dest='init_compare', - action='store_true', default=False, - help='Initializes the comparison feature') + + parser = OptionParser(usage="%prog [options]") + parser.add_option( + "--werkzeug-path", + "-p", + dest="path", + default="..", + help="the path to the werkzeug package. defaults to cwd", + ) + parser.add_option( + "--compare", + "-c", + dest="compare", + nargs=2, + default=False, + help="compare two hg nodes of Werkzeug", + ) + parser.add_option( + "--init-compare", + dest="init_compare", + action="store_true", + default=False, + help="Initializes the comparison feature", + ) options, args = parser.parse_args() if args: - parser.error('Script takes no arguments') + parser.error("Script takes no arguments") if options.compare: compare(*options.compare) elif options.init_compare: @@ -179,65 +200,70 @@ def main(): def init_compare(): """Initializes the comparison feature.""" - print('Initializing comparison feature') - subprocess.Popen(['hg', 'clone', '..', 'a']).wait() - subprocess.Popen(['hg', 'clone', '..', 'b']).wait() + print("Initializing comparison feature") + subprocess.Popen(["hg", "clone", "..", "a"]).wait() + subprocess.Popen(["hg", "clone", "..", "b"]).wait() def compare(node1, node2): """Compares two Werkzeug hg versions.""" - if not os.path.isdir('a'): - print('error: comparison feature not initialized', file=sys.stderr) + if not os.path.isdir("a"): + print("error: comparison feature not initialized", file=sys.stderr) sys.exit(4) - print('=' * 80) - print('WERKZEUG INTERNAL BENCHMARK -- COMPARE MODE'.center(80)) - print('-' * 80) - - def _error(msg): - print('error:', msg, file=sys.stderr) - sys.exit(1) + print("=" * 80) + print("WERKZEUG INTERNAL BENCHMARK -- COMPARE MODE".center(80)) + print("-" * 80) def _hg_update(repo, node): - hg = lambda *x: subprocess.call(['hg'] + list(x), cwd=repo, - stdout=null_out, stderr=null_out) - hg('revert', '-a', '--no-backup') - client = subprocess.Popen(['hg', 'status', '--unknown', '-n', '-0'], - stdout=subprocess.PIPE, cwd=repo) + def hg(*x): + return subprocess.call( + ["hg"] + list(x), cwd=repo, stdout=null_out, stderr=null_out + ) + + hg("revert", "-a", "--no-backup") + client = subprocess.Popen( + ["hg", "status", "--unknown", "-n", "-0"], stdout=subprocess.PIPE, cwd=repo + ) unknown = client.communicate()[0] if unknown: - client = subprocess.Popen(['xargs', '-0', 'rm', '-f'], cwd=repo, - stdout=null_out, stdin=subprocess.PIPE) + client = subprocess.Popen( + ["xargs", "-0", "rm", "-f"], + cwd=repo, + stdout=null_out, + stdin=subprocess.PIPE, + ) client.communicate(unknown) - hg('pull', '../..') - hg('update', node) - if node == 'tip': - diff = subprocess.Popen(['hg', 'diff'], cwd='..', - stdout=subprocess.PIPE).communicate()[0] + hg("pull", "../..") + hg("update", node) + if node == "tip": + diff = subprocess.Popen( + ["hg", "diff"], cwd="..", stdout=subprocess.PIPE + ).communicate()[0] if diff: - client = subprocess.Popen(['hg', 'import', '--no-commit', '-'], - cwd=repo, stdout=null_out, - stdin=subprocess.PIPE) + client = subprocess.Popen( + ["hg", "import", "--no-commit", "-"], + cwd=repo, + stdout=null_out, + stdin=subprocess.PIPE, + ) client.communicate(diff) - _hg_update('a', node1) - _hg_update('b', node2) - d1 = run('a', no_header=True) - d2 = run('b', no_header=True) + _hg_update("a", node1) + _hg_update("b", node2) + d1 = run("a", no_header=True) + d2 = run("b", no_header=True) - print('DIRECT COMPARISON'.center(80)) - print('-' * 80) + print("DIRECT COMPARISON".center(80)) + print("-" * 80) for key in sorted(d1): delta = d1[key] - d2[key] - if abs(1 - d1[key] / d2[key]) < TOLERANCE or \ - abs(delta) < MIN_RESOLUTION: - delta = '==' + if abs(1 - d1[key] / d2[key]) < TOLERANCE or abs(delta) < MIN_RESOLUTION: + delta = "==" else: - delta = '%+.4f (%+d%%)' % \ - (delta, round(d2[key] / d1[key] * 100 - 100)) - print('%36s %.4f %.4f %s' % - (format_func(key), d1[key], d2[key], delta)) - print('-' * 80) + delta = "%+.4f (%+d%%)" % (delta, round(d2[key] / d1[key] * 100 - 100)) + print("%36s %.4f %.4f %s" % (format_func(key), d1[key], d2[key], delta)) + print("-" * 80) def run(path, no_header=False): @@ -245,45 +271,47 @@ def run(path, no_header=False): wz_version, hg_tag = load_werkzeug(path) result = {} if not no_header: - print('=' * 80) - print('WERKZEUG INTERNAL BENCHMARK'.center(80)) - print('-' * 80) - print('Path: %s' % path) - print('Version: %s' % wz_version) + print("=" * 80) + print("WERKZEUG INTERNAL BENCHMARK".center(80)) + print("-" * 80) + print("Path: %s" % path) + print("Version: %s" % wz_version) if hg_tag is not None: - print('HG Tag: %s' % hg_tag) - print('-' * 80) + print("HG Tag: %s" % hg_tag) + print("-" * 80) for key, value in sorted(globals().items()): - if key.startswith('time_'): - before = globals().get('before_' + key[5:]) + if key.startswith("time_"): + before = globals().get("before_" + key[5:]) if before: before() result[key] = bench(value) - after = globals().get('after_' + key[5:]) + after = globals().get("after_" + key[5:]) if after: after() - print('-' * 80) + print("-" * 80) return result URL_DECODED_DATA = dict((str(x), str(x)) for x in xrange(100)) -URL_ENCODED_DATA = '&'.join('%s=%s' % x for x in URL_DECODED_DATA.items()) -MULTIPART_ENCODED_DATA = '\n'.join(( - '--foo', - 'Content-Disposition: form-data; name=foo', - '', - 'this is just bar', - '--foo', - 'Content-Disposition: form-data; name=bar', - '', - 'blafasel', - '--foo', - 'Content-Disposition: form-data; name=foo; filename=wzbench.py', - 'Content-Type: text/plain', - '', - open(__file__.rstrip('c')).read(), - '--foo--' -)) +URL_ENCODED_DATA = "&".join("%s=%s" % x for x in URL_DECODED_DATA.items()) +MULTIPART_ENCODED_DATA = "\n".join( + ( + "--foo", + "Content-Disposition: form-data; name=foo", + "", + "this is just bar", + "--foo", + "Content-Disposition: form-data; name=bar", + "", + "blafasel", + "--foo", + "Content-Disposition: form-data; name=foo; filename=wzbench.py", + "Content-Type: text/plain", + "", + open(__file__.rstrip("c")).read(), + "--foo--", + ) +) MULTIDICT = None REQUEST = None TEST_ENV = None @@ -304,10 +332,10 @@ def time_parse_form_data_multipart(): # from_values which is known to be slowish in 0.5.1 and higher. # we don't want to bench two things at once. environ = { - 'REQUEST_METHOD': 'POST', - 'CONTENT_TYPE': 'multipart/form-data; boundary=foo', - 'wsgi.input': StringIO(MULTIPART_ENCODED_DATA), - 'CONTENT_LENGTH': str(len(MULTIPART_ENCODED_DATA)) + "REQUEST_METHOD": "POST", + "CONTENT_TYPE": "multipart/form-data; boundary=foo", + "wsgi.input": StringIO(MULTIPART_ENCODED_DATA), + "CONTENT_LENGTH": str(len(MULTIPART_ENCODED_DATA)), } request = wz.Request(environ) request.form @@ -315,11 +343,11 @@ def time_parse_form_data_multipart(): def before_multidict_lookup_hit(): global MULTIDICT - MULTIDICT = wz.MultiDict({'foo': 'bar'}) + MULTIDICT = wz.MultiDict({"foo": "bar"}) def time_multidict_lookup_hit(): - MULTIDICT['foo'] + MULTIDICT["foo"] def after_multidict_lookup_hit(): @@ -334,7 +362,7 @@ def before_multidict_lookup_miss(): def time_multidict_lookup_miss(): try: - MULTIDICT['foo'] + MULTIDICT["foo"] except KeyError: pass @@ -351,31 +379,33 @@ def time_cached_property(): return 42 f = Foo() - for x in xrange(60): + for _ in xrange(60): f.x def before_request_form_access(): global REQUEST - data = 'foo=bar&blah=blub' - REQUEST = wz.Request({ - 'CONTENT_LENGTH': str(len(data)), - 'wsgi.input': StringIO(data), - 'REQUEST_METHOD': 'POST', - 'wsgi.version': (1, 0), - 'QUERY_STRING': data, - 'CONTENT_TYPE': 'application/x-www-form-urlencoded', - 'PATH_INFO': '/', - 'SCRIPT_NAME': '' - }) + data = "foo=bar&blah=blub" + REQUEST = wz.Request( + { + "CONTENT_LENGTH": str(len(data)), + "wsgi.input": StringIO(data), + "REQUEST_METHOD": "POST", + "wsgi.version": (1, 0), + "QUERY_STRING": data, + "CONTENT_TYPE": "application/x-www-form-urlencoded", + "PATH_INFO": "/", + "SCRIPT_NAME": "", + } + ) def time_request_form_access(): - for x in xrange(30): + for _ in xrange(30): REQUEST.path REQUEST.script_root - REQUEST.args['foo'] - REQUEST.form['foo'] + REQUEST.args["foo"] + REQUEST.form["foo"] def after_request_form_access(): @@ -384,12 +414,14 @@ def after_request_form_access(): def time_request_from_values(): - wz.Request.from_values(base_url='http://www.google.com/', - query_string='foo=bar&blah=blaz', - input_stream=StringIO(MULTIPART_ENCODED_DATA), - content_length=len(MULTIPART_ENCODED_DATA), - content_type='multipart/form-data; ' - 'boundary=foo', method='POST') + wz.Request.from_values( + base_url="http://www.google.com/", + query_string="foo=bar&blah=blaz", + input_stream=StringIO(MULTIPART_ENCODED_DATA), + content_length=len(MULTIPART_ENCODED_DATA), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) def before_request_shallow_init(): @@ -407,16 +439,14 @@ def after_request_shallow_init(): def time_response_iter_performance(): - resp = wz.Response(u'Hällo Wörld ' * 1000, - mimetype='text/html') - for item in resp({'REQUEST_METHOD': 'GET'}, lambda *s: None): + resp = wz.Response(u"Hällo Wörld " * 1000, mimetype="text/html") + for _ in resp({"REQUEST_METHOD": "GET"}, lambda *s: None): pass def time_response_iter_head_performance(): - resp = wz.Response(u'Hällo Wörld ' * 1000, - mimetype='text/html') - for item in resp({'REQUEST_METHOD': 'HEAD'}, lambda *s: None): + resp = wz.Response(u"Hällo Wörld " * 1000, mimetype="text/html") + for _ in resp({"REQUEST_METHOD": "HEAD"}, lambda *s: None): pass @@ -427,9 +457,9 @@ def before_local_manager_dispatch(): def time_local_manager_dispatch(): - for x in xrange(10): + for _ in xrange(10): LOCAL.x = 42 - for x in xrange(10): + for _ in xrange(10): LOCAL.x @@ -440,14 +470,14 @@ def after_local_manager_dispatch(): def before_html_builder(): global TABLE - TABLE = [['col 1', 'col 2', 'col 3', '4', '5', '6'] for x in range(10)] + TABLE = [["col 1", "col 2", "col 3", "4", "5", "6"] for x in range(10)] def time_html_builder(): html_rows = [] for row in TABLE: # noqa - html_cols = [wz.html.td(col, class_='col') for col in row] - html_rows.append(wz.html.tr(class_='row', *html_cols)) + html_cols = [wz.html.td(col, class_="col") for col in row] + html_rows.append(wz.html.tr(class_="row", *html_cols)) wz.html.table(*html_rows) @@ -456,9 +486,9 @@ def after_html_builder(): TABLE = None -if __name__ == '__main__': +if __name__ == "__main__": os.chdir(os.path.dirname(__file__) or os.path.curdir) try: main() except KeyboardInterrupt: - print('\nInterrupted!', file=sys.stderr) + print("\nInterrupted!", file=sys.stderr) diff --git a/docs/routing.rst b/docs/routing.rst index 9db3e177..f5a768cc 100644 --- a/docs/routing.rst +++ b/docs/routing.rst @@ -221,4 +221,3 @@ Variable parts are of course also possible in the host section:: Rule('/', endpoint='www_index', host='www.example.com'), Rule('/', endpoint='user_index', host='<user>.example.com') ], host_matching=True) - diff --git a/docs/test.rst b/docs/test.rst index eb0130b3..c7e213f8 100644 --- a/docs/test.rst +++ b/docs/test.rst @@ -139,7 +139,7 @@ Testing API A dict with values that are used to override the generated environ. .. attribute:: input_stream - + The optional input stream. This and :attr:`form` / :attr:`files` is mutually exclusive. Also do not provide this stream if the request method is not `POST` / `PUT` or something comparable. diff --git a/docs/unicode.rst b/docs/unicode.rst index 0dc977a7..d8eaa84c 100644 --- a/docs/unicode.rst +++ b/docs/unicode.rst @@ -95,7 +95,7 @@ argument that behaves like the `errors` parameter of the builtin string method Unlike the regular python decoding Werkzeug does not raise an :exc:`UnicodeDecodeError` if the decoding failed but an :exc:`~exceptions.HTTPUnicodeError` which -is a direct subclass of `UnicodeError` and the `BadRequest` HTTP exception. +is a direct subclass of `UnicodeError` and the `BadRequest` HTTP exception. The reason is that if this exception is not caught by the application but a catch-all for HTTP exceptions exists a default `400 BAD REQUEST` error page is displayed. diff --git a/examples/README.rst b/examples/README.rst index 593394a1..2b9df866 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -23,7 +23,7 @@ find in real life :-) A simple Wiki implementation. Requirements: - + - SQLAlchemy - Creoleparser >= 0.7 - genshi @@ -50,7 +50,7 @@ find in real life :-) A planet called plnt, pronounce plant. Requirements: - + - SQLAlchemy - Jinja2 - feedparser @@ -76,7 +76,7 @@ find in real life :-) A tinyurl clone for the Werkzeug tutorial. Requirements: - + - SQLAlchemy - Jinja2 @@ -98,7 +98,7 @@ find in real life :-) Like shorty, but implemented using CouchDB. Requirements : - + - werkzeug : http://werkzeug.pocoo.org - jinja : http://jinja.pocoo.org - couchdb 0.72 & above : https://couchdb.apache.org/ diff --git a/examples/contrib/securecookie.py b/examples/contrib/securecookie.py index a0a566d1..2f6544d3 100644 --- a/examples/contrib/securecookie.py +++ b/examples/contrib/securecookie.py @@ -9,15 +9,16 @@ :license: BSD-3-Clause """ from time import asctime -from werkzeug.serving import run_simple -from werkzeug.wrappers import BaseRequest, BaseResponse + from werkzeug.contrib.securecookie import SecureCookie +from werkzeug.serving import run_simple +from werkzeug.wrappers import BaseRequest +from werkzeug.wrappers import BaseResponse -SECRET_KEY = 'V\x8a$m\xda\xe9\xc3\x0f|f\x88\xbccj>\x8bI^3+' +SECRET_KEY = "V\x8a$m\xda\xe9\xc3\x0f|f\x88\xbccj>\x8bI^3+" class Request(BaseRequest): - def __init__(self, environ): BaseRequest.__init__(self, environ) self.session = SecureCookie.load_cookie(self, secret_key=SECRET_KEY) @@ -28,23 +29,23 @@ def index(request): def get_time(request): - return 'Time: %s' % request.session.get('time', 'not set') + return "Time: %s" % request.session.get("time", "not set") def set_time(request): - request.session['time'] = time = asctime() - return 'Time set to %s' % time + request.session["time"] = time = asctime() + return "Time set to %s" % time def application(environ, start_response): request = Request(environ) - response = BaseResponse({ - 'get': get_time, - 'set': set_time - }.get(request.path.strip('/'), index)(request), mimetype='text/html') + response = BaseResponse( + {"get": get_time, "set": set_time}.get(request.path.strip("/"), index)(request), + mimetype="text/html", + ) request.session.save_cookie(response) return response(environ, start_response) -if __name__ == '__main__': - run_simple('localhost', 5000, application) +if __name__ == "__main__": + run_simple("localhost", 5000, application) diff --git a/examples/contrib/sessions.py b/examples/contrib/sessions.py index ecf464a2..c5eef576 100644 --- a/examples/contrib/sessions.py +++ b/examples/contrib/sessions.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from werkzeug.contrib.sessions import SessionMiddleware +from werkzeug.contrib.sessions import SessionStore from werkzeug.serving import run_simple -from werkzeug.contrib.sessions import SessionStore, SessionMiddleware class MemorySessionStore(SessionStore): - def __init__(self, session_class=None): SessionStore.__init__(self, session_class=None) self.sessions = {} @@ -23,21 +23,22 @@ class MemorySessionStore(SessionStore): def application(environ, start_response): - session = environ['werkzeug.session'] - session['visit_count'] = session.get('visit_count', 0) + 1 + session = environ["werkzeug.session"] + session["visit_count"] = session.get("visit_count", 0) + 1 - start_response('200 OK', [('Content-Type', 'text/html')]) - return [''' - <!doctype html> + start_response("200 OK", [("Content-Type", "text/html")]) + return [ + """<!doctype html> <title>Session Example</title> <h1>Session Example</h1> - <p>You visited this page %d times.</p> - ''' % session['visit_count']] + <p>You visited this page %d times.</p>""" + % session["visit_count"] + ] def make_app(): return SessionMiddleware(application, MemorySessionStore()) -if __name__ == '__main__': - run_simple('localhost', 5000, make_app()) +if __name__ == "__main__": + run_simple("localhost", 5000, make_app()) diff --git a/examples/cookieauth.py b/examples/cookieauth.py index 64b5ae0d..ba23bda4 100644 --- a/examples/cookieauth.py +++ b/examples/cookieauth.py @@ -10,25 +10,25 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from werkzeug.serving import run_simple -from werkzeug.utils import cached_property, escape, redirect -from werkzeug.wrappers import Request, Response from werkzeug.contrib.securecookie import SecureCookie +from werkzeug.serving import run_simple +from werkzeug.utils import cached_property +from werkzeug.utils import escape +from werkzeug.utils import redirect +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response # don't use this key but a different one; you could just use # os.unrandom(20) to get something random. Changing this key # invalidates all sessions at once. -SECRET_KEY = '\xfa\xdd\xb8z\xae\xe0}4\x8b\xea' +SECRET_KEY = "\xfa\xdd\xb8z\xae\xe0}4\x8b\xea" # the cookie name for the session -COOKIE_NAME = 'session' +COOKIE_NAME = "session" # the users that may access -USERS = { - 'admin': 'default', - 'user1': 'default' -} +USERS = {"admin": "default", "user1": "default"} class AppRequest(Request): @@ -36,11 +36,11 @@ class AppRequest(Request): def logout(self): """Log the user out.""" - self.session.pop('username', None) + self.session.pop("username", None) def login(self, username): """Log the user in.""" - self.session['username'] = username + self.session["username"] = username @property def logged_in(self): @@ -50,7 +50,7 @@ class AppRequest(Request): @property def user(self): """The user that is logged in.""" - return self.session.get('username') + return self.session.get("username") @cached_property def session(self): @@ -61,16 +61,16 @@ class AppRequest(Request): def login_form(request): - error = '' - if request.method == 'POST': - username = request.form.get('username') - password = request.form.get('password') + error = "" + if request.method == "POST": + username = request.form.get("username") + password = request.form.get("password") if password and USERS.get(username) == password: request.login(username) - return redirect('') - error = '<p>Invalid credentials' - return Response(''' - <title>Login</title><h1>Login</h1> + return redirect("") + error = "<p>Invalid credentials" + return Response( + """<title>Login</title><h1>Login</h1> <p>Not logged in. %s <form action="" method="post"> @@ -79,23 +79,28 @@ def login_form(request): <input type="text" name="username" size=20> <input type="password" name="password", size=20> <input type="submit" value="Login"> - </form>''' % error, mimetype='text/html') + </form>""" + % error, + mimetype="text/html", + ) def index(request): - return Response(''' - <title>Logged in</title> + return Response( + """<title>Logged in</title> <h1>Logged in</h1> <p>Logged in as %s - <p><a href="/?do=logout">Logout</a> - ''' % escape(request.user), mimetype='text/html') + <p><a href="/?do=logout">Logout</a>""" + % escape(request.user), + mimetype="text/html", + ) @AppRequest.application def application(request): - if request.args.get('do') == 'logout': + if request.args.get("do") == "logout": request.logout() - response = redirect('.') + response = redirect(".") elif request.logged_in: response = index(request) else: @@ -104,5 +109,5 @@ def application(request): return response -if __name__ == '__main__': - run_simple('localhost', 4000, application) +if __name__ == "__main__": + run_simple("localhost", 4000, application) diff --git a/examples/coolmagic/__init__.py b/examples/coolmagic/__init__.py index 88d5db80..0526f1d6 100644 --- a/examples/coolmagic/__init__.py +++ b/examples/coolmagic/__init__.py @@ -8,4 +8,4 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from coolmagic.application import make_app +from .application import make_app diff --git a/examples/coolmagic/application.py b/examples/coolmagic/application.py index 10ffd2a2..730f4d83 100644 --- a/examples/coolmagic/application.py +++ b/examples/coolmagic/application.py @@ -12,11 +12,18 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from os import path, listdir -from coolmagic.utils import Request, local_manager +from os import listdir +from os import path + +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware -from werkzeug.routing import Map, Rule, RequestRedirect -from werkzeug.exceptions import HTTPException, NotFound +from werkzeug.routing import Map +from werkzeug.routing import RequestRedirect +from werkzeug.routing import Rule + +from .utils import local_manager +from .utils import Request class CoolMagicApplication(object): @@ -27,17 +34,17 @@ class CoolMagicApplication(object): def __init__(self, config): self.config = config - for fn in listdir(path.join(path.dirname(__file__), 'views')): - if fn.endswith('.py') and fn != '__init__.py': - __import__('coolmagic.views.' + fn[:-3]) + for fn in listdir(path.join(path.dirname(__file__), "views")): + if fn.endswith(".py") and fn != "__init__.py": + __import__("coolmagic.views." + fn[:-3]) from coolmagic.utils import exported_views + rules = [ # url for shared data. this will always be unmatched # because either the middleware or the webserver # handles that request first. - Rule('/public/<path:file>', - endpoint='shared_data') + Rule("/public/<path:file>", endpoint="shared_data") ] self.views = {} for endpoint, (func, rule, extra) in exported_views.items(): @@ -53,7 +60,7 @@ class CoolMagicApplication(object): endpoint, args = urls.match(req.path) resp = self.views[endpoint](**args) except NotFound: - resp = self.views['static.not_found']() + resp = self.views["static.not_found"]() except (HTTPException, RequestRedirect) as e: resp = e return resp(environ, start_response) @@ -68,9 +75,9 @@ def make_app(config=None): app = CoolMagicApplication(config) # static stuff - app = SharedDataMiddleware(app, { - '/public': path.join(path.dirname(__file__), 'public') - }) + app = SharedDataMiddleware( + app, {"/public": path.join(path.dirname(__file__), "public")} + ) # clean up locals app = local_manager.make_middleware(app) diff --git a/examples/coolmagic/helpers.py b/examples/coolmagic/helpers.py index 54638335..4cd4ac45 100644 --- a/examples/coolmagic/helpers.py +++ b/examples/coolmagic/helpers.py @@ -8,7 +8,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from coolmagic.utils import ThreadedRequest +from .utils import ThreadedRequest #: a thread local proxy request object diff --git a/examples/coolmagic/utils.py b/examples/coolmagic/utils.py index b69f95a4..f4cf20d5 100644 --- a/examples/coolmagic/utils.py +++ b/examples/coolmagic/utils.py @@ -11,17 +11,21 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from os.path import dirname, join -from jinja2 import Environment, FileSystemLoader -from werkzeug.local import Local, LocalManager -from werkzeug.wrappers import BaseRequest, BaseResponse +from os.path import dirname +from os.path import join + +from jinja2 import Environment +from jinja2 import FileSystemLoader +from werkzeug.local import Local +from werkzeug.local import LocalManager +from werkzeug.wrappers import BaseRequest +from werkzeug.wrappers import BaseResponse local = Local() local_manager = LocalManager([local]) template_env = Environment( - loader=FileSystemLoader(join(dirname(__file__), 'templates'), - use_memcache=False) + loader=FileSystemLoader(join(dirname(__file__), "templates"), use_memcache=False) ) exported_views = {} @@ -31,19 +35,23 @@ def export(string, template=None, **extra): Decorator for registering view functions and adding templates to it. """ + def wrapped(f): - endpoint = (f.__module__ + '.' + f.__name__)[16:] + endpoint = (f.__module__ + "." + f.__name__)[16:] if template is not None: old_f = f + def f(**kwargs): rv = old_f(**kwargs) if not isinstance(rv, Response): rv = TemplateResponse(template, **(rv or {})) return rv + f.__name__ = old_f.__name__ f.__doc__ = old_f.__doc__ exported_views[endpoint] = (f, string, extra) return f + return wrapped @@ -59,7 +67,8 @@ class Request(BaseRequest): The concrete request object used in the WSGI application. It has some helper functions that can be used to build URLs. """ - charset = 'utf-8' + + charset = "utf-8" def __init__(self, environ, url_adapter): BaseRequest.__init__(self, environ) @@ -74,9 +83,8 @@ class ThreadedRequest(object): """ def __getattr__(self, name): - if name == '__members__': - return [x for x in dir(local.request) if not - x.startswith('_')] + if name == "__members__": + return [x for x in dir(local.request) if not x.startswith("_")] return getattr(local.request, name) def __setattr__(self, name, value): @@ -87,8 +95,9 @@ class Response(BaseResponse): """ The concrete response object for the WSGI application. """ - charset = 'utf-8' - default_mimetype = 'text/html' + + charset = "utf-8" + default_mimetype = "text/html" class TemplateResponse(Response): @@ -98,9 +107,7 @@ class TemplateResponse(Response): def __init__(self, template_name, **values): from coolmagic import helpers - values.update( - request=local.request, - h=helpers - ) + + values.update(request=local.request, h=helpers) template = template_env.get_template(template_name) Response.__init__(self, template.render(values)) diff --git a/examples/coolmagic/views/static.py b/examples/coolmagic/views/static.py index edb09372..f4f95409 100644 --- a/examples/coolmagic/views/static.py +++ b/examples/coolmagic/views/static.py @@ -11,22 +11,22 @@ from coolmagic.utils import export -@export('/', template='static/index.html') +@export("/", template="static/index.html") def index(): pass -@export('/about', template='static/about.html') +@export("/about", template="static/about.html") def about(): pass -@export('/broken') +@export("/broken") def broken(): - raise RuntimeError('that\'s really broken') + raise RuntimeError("that's really broken") -@export(None, template='static/not_found.html') +@export(None, template="static/not_found.html") def not_found(): """ This function is always executed if an url does not diff --git a/examples/couchy/README b/examples/couchy/README index 1821ed11..24960448 100644 --- a/examples/couchy/README +++ b/examples/couchy/README @@ -5,4 +5,3 @@ Requirements : - jinja : http://jinja.pocoo.org - couchdb 0.72 & above : https://couchdb.apache.org/ - couchdb-python 0.3 & above : https://github.com/djc/couchdb-python - diff --git a/examples/couchy/application.py b/examples/couchy/application.py index f3a56a63..b958dcbc 100644 --- a/examples/couchy/application.py +++ b/examples/couchy/application.py @@ -1,27 +1,28 @@ from couchdb.client import Server -from couchy.utils import STATIC_PATH, local, local_manager, \ - url_map +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.wrappers import Request from werkzeug.wsgi import ClosingIterator -from werkzeug.exceptions import HTTPException, NotFound -from couchy import views -from couchy.models import URL +from . import views +from .models import URL +from .utils import local +from .utils import local_manager +from .utils import STATIC_PATH +from .utils import url_map -class Couchy(object): +class Couchy(object): def __init__(self, db_uri): local.application = self server = Server(db_uri) try: - db = server.create('urls') - except: - db = server['urls'] - self.dispatch = SharedDataMiddleware(self.dispatch, { - '/static': STATIC_PATH - }) + db = server.create("urls") + except Exception: + db = server["urls"] + self.dispatch = SharedDataMiddleware(self.dispatch, {"/static": STATIC_PATH}) URL.db = db @@ -38,8 +39,9 @@ class Couchy(object): response.status_code = 404 except HTTPException as e: response = e - return ClosingIterator(response(environ, start_response), - [local_manager.cleanup]) + return ClosingIterator( + response(environ, start_response), [local_manager.cleanup] + ) def __call__(self, environ, start_response): return self.dispatch(environ, start_response) diff --git a/examples/couchy/models.py b/examples/couchy/models.py index 4621a744..a0b50ca1 100644 --- a/examples/couchy/models.py +++ b/examples/couchy/models.py @@ -1,6 +1,12 @@ from datetime import datetime -from couchdb.mapping import Document, TextField, BooleanField, DateTimeField -from couchy.utils import url_for, get_random_uid + +from couchdb.mapping import BooleanField +from couchdb.mapping import DateTimeField +from couchdb.mapping import Document +from couchdb.mapping import TextField + +from .utils import get_random_uid +from .utils import url_for class URL(Document): @@ -19,13 +25,15 @@ class URL(Document): return URL.db.query(code) def store(self): - if getattr(self._data, 'id', None) is None: + if getattr(self._data, "id", None) is None: new_id = self.shorty_id if self.shorty_id else None while 1: id = new_id if new_id else get_random_uid() try: - docid = URL.db.resource.put(content=self._data, path='/%s/' % str(id))['id'] - except: + docid = URL.db.resource.put( + content=self._data, path="/%s/" % str(id) + )["id"] + except Exception: continue if docid: break @@ -36,7 +44,7 @@ class URL(Document): @property def short_url(self): - return url_for('link', uid=self.id, _external=True) + return url_for("link", uid=self.id, _external=True) def __repr__(self): - return '<URL %r>' % self.id + return "<URL %r>" % self.id diff --git a/examples/couchy/utils.py b/examples/couchy/utils.py index 4fe666a3..571a7ed9 100644 --- a/examples/couchy/utils.py +++ b/examples/couchy/utils.py @@ -1,49 +1,62 @@ from os import path -from random import sample, randrange -from jinja2 import Environment, FileSystemLoader -from werkzeug.local import Local, LocalManager +from random import randrange +from random import sample + +from jinja2 import Environment +from jinja2 import FileSystemLoader +from werkzeug.local import Local +from werkzeug.local import LocalManager +from werkzeug.routing import Map +from werkzeug.routing import Rule from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response -from werkzeug.routing import Map, Rule -TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates') -STATIC_PATH = path.join(path.dirname(__file__), 'static') -ALLOWED_SCHEMES = frozenset(['http', 'https', 'ftp', 'ftps']) -URL_CHARS = 'abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789' +TEMPLATE_PATH = path.join(path.dirname(__file__), "templates") +STATIC_PATH = path.join(path.dirname(__file__), "static") +ALLOWED_SCHEMES = frozenset(["http", "https", "ftp", "ftps"]) +URL_CHARS = "abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789" local = Local() local_manager = LocalManager([local]) -application = local('application') +application = local("application") -url_map = Map([Rule('/static/<file>', endpoint='static', build_only=True)]) +url_map = Map([Rule("/static/<file>", endpoint="static", build_only=True)]) jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH)) def expose(rule, **kw): def decorate(f): - kw['endpoint'] = f.__name__ + kw["endpoint"] = f.__name__ url_map.add(Rule(rule, **kw)) return f + return decorate + def url_for(endpoint, _external=False, **values): return local.url_adapter.build(endpoint, values, force_external=_external) -jinja_env.globals['url_for'] = url_for + + +jinja_env.globals["url_for"] = url_for + def render_template(template, **context): - return Response(jinja_env.get_template(template).render(**context), - mimetype='text/html') + return Response( + jinja_env.get_template(template).render(**context), mimetype="text/html" + ) + def validate_url(url): return url_parse(url)[0] in ALLOWED_SCHEMES + def get_random_uid(): - return ''.join(sample(URL_CHARS, randrange(3, 9))) + return "".join(sample(URL_CHARS, randrange(3, 9))) -class Pagination(object): +class Pagination(object): def __init__(self, results, per_page, page, endpoint): self.results = results self.per_page = per_page @@ -56,7 +69,11 @@ class Pagination(object): @cached_property def entries(self): - return self.results[((self.page - 1) * self.per_page):(((self.page - 1) * self.per_page)+self.per_page)] + return self.results[ + ((self.page - 1) * self.per_page) : ( + ((self.page - 1) * self.per_page) + self.per_page + ) + ] has_previous = property(lambda self: self.page > 1) has_next = property(lambda self: self.page < self.pages) diff --git a/examples/couchy/views.py b/examples/couchy/views.py index 39c8ea29..c1547e7d 100644 --- a/examples/couchy/views.py +++ b/examples/couchy/views.py @@ -1,61 +1,73 @@ -from werkzeug.utils import redirect from werkzeug.exceptions import NotFound -from couchy.utils import render_template, expose, \ - validate_url, url_for, Pagination -from couchy.models import URL +from werkzeug.utils import redirect +from .models import URL +from .utils import expose +from .utils import Pagination +from .utils import render_template +from .utils import url_for +from .utils import validate_url -@expose('/') + +@expose("/") def new(request): - error = url = '' - if request.method == 'POST': - url = request.form.get('url') - alias = request.form.get('alias') + error = url = "" + if request.method == "POST": + url = request.form.get("url") + alias = request.form.get("alias") if not validate_url(url): error = "I'm sorry but you cannot shorten this URL." elif alias: if len(alias) > 140: - error = 'Your alias is too long' - elif '/' in alias: - error = 'Your alias might not include a slash' + error = "Your alias is too long" + elif "/" in alias: + error = "Your alias might not include a slash" elif URL.load(alias): - error = 'The alias you have requested exists already' + error = "The alias you have requested exists already" if not error: - url = URL(target=url, public='private' not in request.form, shorty_id=alias if alias else None) + url = URL( + target=url, + public="private" not in request.form, + shorty_id=alias if alias else None, + ) url.store() uid = url.id - return redirect(url_for('display', uid=uid)) - return render_template('new.html', error=error, url=url) + return redirect(url_for("display", uid=uid)) + return render_template("new.html", error=error, url=url) + -@expose('/display/<uid>') +@expose("/display/<uid>") def display(request, uid): url = URL.load(uid) if not url: raise NotFound() - return render_template('display.html', url=url) + return render_template("display.html", url=url) -@expose('/u/<uid>') + +@expose("/u/<uid>") def link(request, uid): url = URL.load(uid) if not url: raise NotFound() return redirect(url.target, 301) -@expose('/list/', defaults={'page': 1}) -@expose('/list/<int:page>') + +@expose("/list/", defaults={"page": 1}) +@expose("/list/<int:page>") def list(request, page): def wrap(doc): data = doc.value - data['_id'] = doc.id + data["_id"] = doc.id return URL.wrap(data) - code = '''function(doc) { if (doc.public){ map([doc._id], doc); }}''' + code = """function(doc) { if (doc.public){ map([doc._id], doc); }}""" docResults = URL.query(code) results = [wrap(doc) for doc in docResults] - pagination = Pagination(results, 1, page, 'list') + pagination = Pagination(results, 1, page, "list") if pagination.page > 1 and not pagination.entries: raise NotFound() - return render_template('list.html', pagination=pagination) + return render_template("list.html", pagination=pagination) + def not_found(request): - return render_template('not_found.html') + return render_template("not_found.html") diff --git a/examples/cupoftee/__init__.py b/examples/cupoftee/__init__.py index f4a9d04d..184c5d0d 100644 --- a/examples/cupoftee/__init__.py +++ b/examples/cupoftee/__init__.py @@ -8,4 +8,4 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from cupoftee.application import make_app +from .application import make_app diff --git a/examples/cupoftee/application.py b/examples/cupoftee/application.py index 140b6b74..540e3f59 100644 --- a/examples/cupoftee/application.py +++ b/examples/cupoftee/application.py @@ -9,44 +9,58 @@ :license: BSD-3-Clause """ import time - -from jinja2 import Environment, PackageLoader from os import path from threading import Thread -from cupoftee.db import Database -from cupoftee.network import ServerBrowser +from jinja2 import Environment +from jinja2 import PackageLoader +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware -from werkzeug.wrappers import Request, Response -from werkzeug.exceptions import HTTPException, NotFound -from werkzeug.routing import Map, Rule +from werkzeug.routing import Map +from werkzeug.routing import Rule +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response + +from .db import Database +from .network import ServerBrowser -templates = path.join(path.dirname(__file__), 'templates') +templates = path.join(path.dirname(__file__), "templates") pages = {} -url_map = Map([Rule('/shared/<file>', endpoint='shared')]) +url_map = Map([Rule("/shared/<file>", endpoint="shared")]) def make_app(database, interval=120): - return SharedDataMiddleware(Cup(database, interval), { - '/shared': path.join(path.dirname(__file__), 'shared') - }) + return SharedDataMiddleware( + Cup(database, interval), + {"/shared": path.join(path.dirname(__file__), "shared")}, + ) class PageMeta(type): - def __init__(cls, name, bases, d): type.__init__(cls, name, bases, d) - if d.get('url_rule') is not None: + if d.get("url_rule") is not None: pages[cls.identifier] = cls - url_map.add(Rule(cls.url_rule, endpoint=cls.identifier, - **cls.url_arguments)) + url_map.add( + Rule(cls.url_rule, endpoint=cls.identifier, **cls.url_arguments) + ) identifier = property(lambda self: self.__name__.lower()) -class Page(object): - __metaclass__ = PageMeta +def _with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + + class metaclass(type): + def __new__(metacls, name, this_bases, d): + return meta(name, bases, d) + + return type.__new__(metaclass, "temporary_class", (), {}) + + +class Page(_with_metaclass(PageMeta, object)): url_arguments = {} def __init__(self, cup, request, url_adapter): @@ -62,17 +76,16 @@ class Page(object): def render_template(self, template=None): if template is None: - template = self.__class__.identifier + '.html' + template = self.__class__.identifier + ".html" context = dict(self.__dict__) context.update(url_for=self.url_for, self=self) return self.cup.render_template(template, context) def get_response(self): - return Response(self.render_template(), mimetype='text/html') + return Response(self.render_template(), mimetype="text/html") class Cup(object): - def __init__(self, database, interval=120): self.jinja_env = Environment(loader=PackageLoader("cupoftee"), autoescape=True) self.interval = interval @@ -111,4 +124,5 @@ class Cup(object): template = self.jinja_env.get_template(name) return template.render(context) + from cupoftee.pages import MissingPage diff --git a/examples/cupoftee/db.py b/examples/cupoftee/db.py index 9db97c3d..7f041220 100644 --- a/examples/cupoftee/db.py +++ b/examples/cupoftee/db.py @@ -9,8 +9,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +from pickle import dumps +from pickle import loads from threading import Lock -from pickle import dumps, loads try: import dbm @@ -19,10 +20,9 @@ except ImportError: class Database(object): - def __init__(self, filename): self.filename = filename - self._fs = dbm.open(filename, 'cf') + self._fs = dbm.open(filename, "cf") self._local = {} self._lock = Lock() @@ -43,7 +43,7 @@ class Database(object): def __delitem__(self, key, value): with self._lock: self._local.pop(key, None) - if self._fs.has_key(key): + if key in self._fs: del self._fs[key] def __del__(self): @@ -75,5 +75,5 @@ class Database(object): try: self.sync() self._fs.close() - except: + except Exception: pass diff --git a/examples/cupoftee/network.py b/examples/cupoftee/network.py index 0e472c1b..74c775aa 100644 --- a/examples/cupoftee/network.py +++ b/examples/cupoftee/network.py @@ -9,9 +9,10 @@ :license: BSD-3-Clause """ import socket -from math import log from datetime import datetime -from cupoftee.utils import unicodecmp +from math import log + +from .utils import unicodecmp class ServerError(Exception): @@ -31,15 +32,14 @@ class Syncable(object): class ServerBrowser(Syncable): - def __init__(self, cup): self.cup = cup - self.servers = cup.db.setdefault('servers', dict) + self.servers = cup.db.setdefault("servers", dict) def _sync(self): to_delete = set(self.servers) for x in range(1, 17): - addr = ('master%d.teeworlds.com' % x, 8300) + addr = ("master%d.teeworlds.com" % x, 8300) print(addr) try: self._sync_master(addr, to_delete) @@ -48,20 +48,22 @@ class ServerBrowser(Syncable): for server_id in to_delete: self.servers.pop(server_id, None) if not self.servers: - raise IOError('no servers found') + raise IOError("no servers found") self.cup.db.sync() def _sync_master(self, addr, to_delete): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.settimeout(5) - s.sendto(b'\x20\x00\x00\x00\x00\x48\xff\xff\xff\xffreqt', addr) + s.sendto(b"\x20\x00\x00\x00\x00\x48\xff\xff\xff\xffreqt", addr) data = s.recvfrom(1024)[0][14:] s.close() for n in range(0, len(data) // 6): - addr = ('.'.join(map(str, map(ord, data[n * 6:n * 6 + 4]))), - ord(data[n * 6 + 5]) * 256 + ord(data[n * 6 + 4])) - server_id = '%s:%d' % addr + addr = ( + ".".join(map(str, map(ord, data[n * 6 : n * 6 + 4]))), + ord(data[n * 6 + 5]) * 256 + ord(data[n * 6 + 4]), + ) + server_id = "%s:%d" % addr if server_id in self.servers: if not self.servers[server_id].sync(): continue @@ -74,31 +76,31 @@ class ServerBrowser(Syncable): class Server(Syncable): - def __init__(self, addr, server_id): self.addr = addr self.id = server_id self.players = [] if not self.sync(): - raise ServerError('server not responding in time') + raise ServerError("server not responding in time") def _sync(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.settimeout(1) - s.sendto(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xffgief', self.addr) - bits = s.recvfrom(1024)[0][14:].split(b'\x00') + s.sendto(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xffgief", self.addr) + bits = s.recvfrom(1024)[0][14:].split(b"\x00") s.close() self.version, server_name, map_name = bits[:3] - self.name = server_name.decode('latin1') - self.map = map_name.decode('latin1') + self.name = server_name.decode("latin1") + self.map = map_name.decode("latin1") self.gametype = bits[3] - self.flags, self.progression, player_count, \ - self.max_players = map(int, bits[4:8]) + self.flags, self.progression, player_count, self.max_players = map( + int, bits[4:8] + ) # sync the player stats players = dict((p.name, p) for p in self.players) for i in range(player_count): - name = bits[8 + i * 2].decode('latin1') + name = bits[8 + i * 2].decode("latin1") score = int(bits[9 + i * 2]) # update existing player @@ -112,7 +114,7 @@ class Server(Syncable): for player in players.values(): try: self.players.remove(player) - except: + except Exception: pass # sort the player list and count them @@ -124,7 +126,6 @@ class Server(Syncable): class Player(object): - def __init__(self, server, name, score): self.server = server self.name = name diff --git a/examples/cupoftee/pages.py b/examples/cupoftee/pages.py index 7d799c25..c1a823b7 100644 --- a/examples/cupoftee/pages.py +++ b/examples/cupoftee/pages.py @@ -10,41 +10,42 @@ """ from functools import reduce -from werkzeug.utils import redirect from werkzeug.exceptions import NotFound -from cupoftee.application import Page -from cupoftee.utils import unicodecmp +from werkzeug.utils import redirect + +from .application import Page +from .utils import unicodecmp class ServerList(Page): - url_rule = '/' + url_rule = "/" def order_link(self, name, title): - cls = '' - link = '?order_by=' + name + cls = "" + link = "?order_by=" + name desc = False if name == self.order_by: desc = not self.order_desc - cls = ' class="%s"' % ('down' if desc else 'up') + cls = ' class="%s"' % ("down" if desc else "up") if desc: - link += '&dir=desc' + link += "&dir=desc" return '<a href="%s"%s>%s</a>' % (link, cls, title) def process(self): - self.order_by = self.request.args.get('order_by') or 'name' + self.order_by = self.request.args.get("order_by") or "name" sort_func = { - 'name': lambda x: x, - 'map': lambda x: x.map, - 'gametype': lambda x: x.gametype, - 'players': lambda x: x.player_count, - 'progression': lambda x: x.progression, + "name": lambda x: x, + "map": lambda x: x.map, + "gametype": lambda x: x.gametype, + "players": lambda x: x.player_count, + "progression": lambda x: x.progression, }.get(self.order_by) if sort_func is None: - return redirect(self.url_for('serverlist')) + return redirect(self.url_for("serverlist")) self.servers = self.cup.master.servers.values() self.servers.sort(key=sort_func) - if self.request.args.get('dir') == 'desc': + if self.request.args.get("dir") == "desc": self.servers.reverse() self.order_desc = True else: @@ -55,7 +56,7 @@ class ServerList(Page): class Server(Page): - url_rule = '/server/<id>' + url_rule = "/server/<id>" def process(self, id): try: @@ -65,10 +66,10 @@ class Server(Page): class Search(Page): - url_rule = '/search' + url_rule = "/search" def process(self): - self.user = self.request.args.get('user') + self.user = self.request.args.get("user") if self.user: self.results = [] for server in self.cup.master.servers.values(): @@ -78,7 +79,6 @@ class Search(Page): class MissingPage(Page): - def get_response(self): response = super(MissingPage, self).get_response() response.status_code = 404 diff --git a/examples/cupoftee/utils.py b/examples/cupoftee/utils.py index f13e2f3d..91717e85 100644 --- a/examples/cupoftee/utils.py +++ b/examples/cupoftee/utils.py @@ -11,7 +11,7 @@ import re -_sort_re = re.compile(r'\w+', re.UNICODE) +_sort_re = re.compile(r"\w+", re.UNICODE) def unicodecmp(a, b): diff --git a/examples/httpbasicauth.py b/examples/httpbasicauth.py index 5a774f3f..21d39300 100644 --- a/examples/httpbasicauth.py +++ b/examples/httpbasicauth.py @@ -10,12 +10,12 @@ :license: BSD-3-Clause """ from werkzeug.serving import run_simple -from werkzeug.wrappers import Request, Response +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response class Application(object): - - def __init__(self, users, realm='login required'): + def __init__(self, users, realm="login required"): self.users = users self.realm = realm @@ -23,12 +23,15 @@ class Application(object): return username in self.users and self.users[username] == password def auth_required(self, request): - return Response('Could not verify your access level for that URL.\n' - 'You have to login with proper credentials', 401, - {'WWW-Authenticate': 'Basic realm="%s"' % self.realm}) + return Response( + "Could not verify your access level for that URL.\n" + "You have to login with proper credentials", + 401, + {"WWW-Authenticate": 'Basic realm="%s"' % self.realm}, + ) def dispatch_request(self, request): - return Response('Logged in as %s' % request.authorization.username) + return Response("Logged in as %s" % request.authorization.username) def __call__(self, environ, start_response): request = Request(environ) @@ -40,6 +43,6 @@ class Application(object): return response(environ, start_response) -if __name__ == '__main__': - application = Application({'user1': 'password', 'user2': 'password'}) - run_simple('localhost', 5000, application) +if __name__ == "__main__": + application = Application({"user1": "password", "user2": "password"}) + run_simple("localhost", 5000, application) diff --git a/examples/i18nurls/__init__.py b/examples/i18nurls/__init__.py index 393d270c..f5f5c6ed 100644 --- a/examples/i18nurls/__init__.py +++ b/examples/i18nurls/__init__.py @@ -1 +1 @@ -from i18nurls.application import Application as make_app +from .application import Application as make_app diff --git a/examples/i18nurls/application.py b/examples/i18nurls/application.py index 54f12353..103f6001 100644 --- a/examples/i18nurls/application.py +++ b/examples/i18nurls/application.py @@ -1,33 +1,39 @@ -from jinja2 import Environment, PackageLoader from os import path -from werkzeug.wrappers import Request as _Request, BaseResponse + +from jinja2 import Environment +from jinja2 import PackageLoader +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.routing import RequestRedirect -from werkzeug.exceptions import HTTPException, NotFound -from i18nurls.urls import map +from werkzeug.wrappers import BaseResponse +from werkzeug.wrappers import Request as _Request + +from .urls import map -TEMPLATES = path.join(path.dirname(__file__), 'templates') +TEMPLATES = path.join(path.dirname(__file__), "templates") views = {} def expose(name): """Register the function as view.""" + def wrapped(f): views[name] = f return f + return wrapped class Request(_Request): - def __init__(self, environ, urls): super(Request, self).__init__(environ) self.urls = urls self.matched_url = None def url_for(self, endpoint, **args): - if not 'lang_code' in args: - args['lang_code'] = self.language - if endpoint == 'this': + if "lang_code" not in args: + args["lang_code"] = self.language + if endpoint == "this": endpoint = self.matched_url[0] tmp = self.matched_url[1].copy() tmp.update(args) @@ -45,12 +51,12 @@ class TemplateResponse(Response): def __init__(self, template_name, **values): self.template_name = template_name self.template_values = values - Response.__init__(self, mimetype='text/html') + Response.__init__(self, mimetype="text/html") def __call__(self, environ, start_response): - req = environ['werkzeug.request'] + req = environ["werkzeug.request"] values = self.template_values.copy() - values['req'] = req + values["req"] = req self.data = self.render_template(self.template_name, values) return super(TemplateResponse, self).__call__(environ, start_response) @@ -60,9 +66,9 @@ class TemplateResponse(Response): class Application(object): - def __init__(self): from i18nurls import views + self.not_found = views.page_not_found def __call__(self, environ, start_response): @@ -71,14 +77,14 @@ class Application(object): try: endpoint, args = urls.match(req.path) req.matched_url = (endpoint, args) - if endpoint == '#language_select': + if endpoint == "#language_select": lng = req.accept_languages.best - lng = lng and lng.split('-')[0].lower() or 'en' - index_url = urls.build('index', {'lang_code': lng}) - resp = Response('Moved to %s' % index_url, status=302) - resp.headers['Location'] = index_url + lng = lng and lng.split("-")[0].lower() or "en" + index_url = urls.build("index", {"lang_code": lng}) + resp = Response("Moved to %s" % index_url, status=302) + resp.headers["Location"] = index_url else: - req.language = args.pop('lang_code', None) + req.language = args.pop("lang_code", None) resp = views[endpoint](req, **args) except NotFound: resp = self.not_found(req) diff --git a/examples/i18nurls/urls.py b/examples/i18nurls/urls.py index 57ff68a2..3dd54a00 100644 --- a/examples/i18nurls/urls.py +++ b/examples/i18nurls/urls.py @@ -1,11 +1,18 @@ -from werkzeug.routing import Map, Rule, Submount +from werkzeug.routing import Map +from werkzeug.routing import Rule +from werkzeug.routing import Submount -map = Map([ - Rule('/', endpoint='#language_select'), - Submount('/<string(length=2):lang_code>', [ - Rule('/', endpoint='index'), - Rule('/about', endpoint='about'), - Rule('/blog/', endpoint='blog/index'), - Rule('/blog/<int:post_id>', endpoint='blog/show') - ]) -]) +map = Map( + [ + Rule("/", endpoint="#language_select"), + Submount( + "/<string(length=2):lang_code>", + [ + Rule("/", endpoint="index"), + Rule("/about", endpoint="about"), + Rule("/blog/", endpoint="blog/index"), + Rule("/blog/<int:post_id>", endpoint="blog/show"), + ], + ), + ] +) diff --git a/examples/i18nurls/views.py b/examples/i18nurls/views.py index 7b8bdb70..6c0ca895 100644 --- a/examples/i18nurls/views.py +++ b/examples/i18nurls/views.py @@ -1,22 +1,29 @@ -from i18nurls.application import TemplateResponse, Response, expose +from .application import expose +from .application import Response +from .application import TemplateResponse -@expose('index') +@expose("index") def index(req): - return TemplateResponse('index.html', title='Index') + return TemplateResponse("index.html", title="Index") -@expose('about') + +@expose("about") def about(req): - return TemplateResponse('about.html', title='About') + return TemplateResponse("about.html", title="About") + -@expose('blog/index') +@expose("blog/index") def blog_index(req): - return TemplateResponse('blog.html', title='Blog Index', mode='index') + return TemplateResponse("blog.html", title="Blog Index", mode="index") -@expose('blog/show') + +@expose("blog/show") def blog_show(req, post_id): - return TemplateResponse('blog.html', title='Blog Post #%d' % post_id, - post_id=post_id, mode='show') + return TemplateResponse( + "blog.html", title="Blog Post #%d" % post_id, post_id=post_id, mode="show" + ) + def page_not_found(req): - return Response('<h1>Page Not Found</h1>', mimetype='text/html') + return Response("<h1>Page Not Found</h1>", mimetype="text/html") diff --git a/examples/manage-coolmagic.py b/examples/manage-coolmagic.py index 8059003f..f6abe80e 100755 --- a/examples/manage-coolmagic.py +++ b/examples/manage-coolmagic.py @@ -10,9 +10,10 @@ :license: BSD-3-Clause """ import click -from coolmagic import make_app from werkzeug.serving import run_simple +from coolmagic import make_app + @click.group() def cli(): @@ -20,36 +21,45 @@ def cli(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_app() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = dict() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -57,7 +67,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-couchy.py b/examples/manage-couchy.py index 1bbf4a67..8a00a400 100755 --- a/examples/manage-couchy.py +++ b/examples/manage-couchy.py @@ -5,17 +5,15 @@ from werkzeug.serving import run_simple def make_app(): from couchy.application import Couchy - return Couchy('http://localhost:5984') + + return Couchy("http://localhost:5984") def make_shell(): from couchy import models, utils + application = make_app() - return { - "application": application, - "models": models, - "utils": utils, - } + return {"application": application, "models": models, "utils": utils} @click.group() @@ -26,40 +24,50 @@ def cli(): @cli.command() def initdb(): from couchy.application import Couchy - Couchy('http://localhost:5984').init_database() + + Couchy("http://localhost:5984").init_database() @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_app() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = make_shell() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -67,7 +75,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-cupoftee.py b/examples/manage-cupoftee.py index e905f06b..a787f324 100755 --- a/examples/manage-cupoftee.py +++ b/examples/manage-cupoftee.py @@ -14,7 +14,8 @@ from werkzeug.serving import run_simple def make_app(): from cupoftee import make_app - return make_app('/tmp/cupoftee.db') + + return make_app("/tmp/cupoftee.db") @click.group() @@ -23,20 +24,27 @@ def cli(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, reloader, debugger, evalex, threaded, processes): """Start a new development server.""" app = make_app() - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) - - -if __name__ == '__main__': + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) + + +if __name__ == "__main__": cli() diff --git a/examples/manage-i18nurls.py b/examples/manage-i18nurls.py index 04a5068e..a2b746a2 100755 --- a/examples/manage-i18nurls.py +++ b/examples/manage-i18nurls.py @@ -10,9 +10,10 @@ :license: BSD-3-Clause """ import click -from i18nurls import make_app from werkzeug.serving import run_simple +from i18nurls import make_app + @click.group() def cli(): @@ -20,36 +21,45 @@ def cli(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_app() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = dict() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -57,7 +67,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-plnt.py b/examples/manage-plnt.py index 8ce5b229..80a03f78 100755 --- a/examples/manage-plnt.py +++ b/examples/manage-plnt.py @@ -9,16 +9,18 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import click import os + +import click from werkzeug.serving import run_simple def make_app(): """Helper function that creates a plnt app.""" from plnt import Plnt - database_uri = os.environ.get('PLNT_DATABASE_URI') - app = Plnt(database_uri or 'sqlite:////tmp/plnt.db') + + database_uri = os.environ.get("PLNT_DATABASE_URI") + app = Plnt(database_uri or "sqlite:////tmp/plnt.db") app.bind_to_context() return app @@ -32,62 +34,88 @@ def cli(): def initdb(): """Initialize the database""" from plnt.database import Blog, session + make_app().init_database() # and now fill in some python blogs everybody should read (shamelessly # added my own blog too) blogs = [ - Blog('Armin Ronacher', 'http://lucumr.pocoo.org/', - 'http://lucumr.pocoo.org/cogitations/feed/'), - Blog('Georg Brandl', 'http://pyside.blogspot.com/', - 'http://pyside.blogspot.com/feeds/posts/default'), - Blog('Ian Bicking', 'http://blog.ianbicking.org/', - 'http://blog.ianbicking.org/feed/'), - Blog('Amir Salihefendic', 'http://amix.dk/', - 'http://feeds.feedburner.com/amixdk'), - Blog('Christopher Lenz', 'http://www.cmlenz.net/blog/', - 'http://www.cmlenz.net/blog/atom.xml'), - Blog('Frederick Lundh', 'http://online.effbot.org/', - 'http://online.effbot.org/rss.xml') + Blog( + "Armin Ronacher", + "http://lucumr.pocoo.org/", + "http://lucumr.pocoo.org/cogitations/feed/", + ), + Blog( + "Georg Brandl", + "http://pyside.blogspot.com/", + "http://pyside.blogspot.com/feeds/posts/default", + ), + Blog( + "Ian Bicking", + "http://blog.ianbicking.org/", + "http://blog.ianbicking.org/feed/", + ), + Blog( + "Amir Salihefendic", "http://amix.dk/", "http://feeds.feedburner.com/amixdk" + ), + Blog( + "Christopher Lenz", + "http://www.cmlenz.net/blog/", + "http://www.cmlenz.net/blog/atom.xml", + ), + Blog( + "Frederick Lundh", + "http://online.effbot.org/", + "http://online.effbot.org/rss.xml", + ), ] # okay. got tired here. if someone feels that he is missing, drop me # a line ;-) for blog in blogs: session.add(blog) session.commit() - click.echo('Initialized database, now run manage-plnt.py sync to get the posts') + click.echo("Initialized database, now run manage-plnt.py sync to get the posts") @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_app() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' - namespace = {'app': make_app()} + banner = "Interactive Werkzeug Shell" + namespace = {"app": make_app()} if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -95,6 +123,7 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) @@ -102,8 +131,10 @@ def shell(no_ipython): def sync(): """Sync the blogs in the planet. Call this from a cronjob.""" from plnt.sync import sync + make_app().bind_to_context() sync() -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-shorty.py b/examples/manage-shorty.py index d80eccab..80fb3731 100755 --- a/examples/manage-shorty.py +++ b/examples/manage-shorty.py @@ -1,24 +1,23 @@ #!/usr/bin/env python -import click import os import tempfile + +import click from werkzeug.serving import run_simple def make_app(): from shorty.application import Shorty + filename = os.path.join(tempfile.gettempdir(), "shorty.db") - return Shorty('sqlite:///{0}'.format(filename)) + return Shorty("sqlite:///{0}".format(filename)) def make_shell(): from shorty import models, utils + application = make_app() - return { - "application": application, - "models": models, - "utils": utils, - } + return {"application": application, "models": models, "utils": utils} @click.group() @@ -32,36 +31,45 @@ def initdb(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_app() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = make_shell() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -69,7 +77,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-simplewiki.py b/examples/manage-simplewiki.py index 3e23ceb1..5fd66b85 100755 --- a/examples/manage-simplewiki.py +++ b/examples/manage-simplewiki.py @@ -9,26 +9,26 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import click import os + +import click from werkzeug.serving import run_simple def make_wiki(): """Helper function that creates a new wiki instance.""" from simplewiki import SimpleWiki - database_uri = os.environ.get('SIMPLEWIKI_DATABASE_URI') - return SimpleWiki(database_uri or 'sqlite:////tmp/simplewiki.db') + + database_uri = os.environ.get("SIMPLEWIKI_DATABASE_URI") + return SimpleWiki(database_uri or "sqlite:////tmp/simplewiki.db") def make_shell(): from simplewiki import database + wiki = make_wiki() wiki.bind_to_context() - return { - 'wiki': wiki, - 'db': database - } + return {"wiki": wiki, "db": database} @click.group() @@ -42,36 +42,45 @@ def initdb(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" app = make_wiki() reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = make_shell() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -79,7 +88,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/manage-webpylike.py b/examples/manage-webpylike.py index 2e77714d..010bdcee 100755 --- a/examples/manage-webpylike.py +++ b/examples/manage-webpylike.py @@ -13,13 +13,16 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import click import os import sys -sys.path.append(os.path.join(os.path.dirname(__file__), 'webpylike')) -from webpylike.example import app + +import click from werkzeug.serving import run_simple +from webpylike.example import app + +sys.path.append(os.path.join(os.path.dirname(__file__), "webpylike")) + @click.group() def cli(): @@ -27,35 +30,44 @@ def cli(): @cli.command() -@click.option('-h', '--hostname', type=str, default='localhost', help="localhost") -@click.option('-p', '--port', type=int, default=5000, help="5000") -@click.option('--no-reloader', is_flag=True, default=False) -@click.option('--debugger', is_flag=True) -@click.option('--no-evalex', is_flag=True, default=False) -@click.option('--threaded', is_flag=True) -@click.option('--processes', type=int, default=1, help="1") +@click.option("-h", "--hostname", type=str, default="localhost", help="localhost") +@click.option("-p", "--port", type=int, default=5000, help="5000") +@click.option("--no-reloader", is_flag=True, default=False) +@click.option("--debugger", is_flag=True) +@click.option("--no-evalex", is_flag=True, default=False) +@click.option("--threaded", is_flag=True) +@click.option("--processes", type=int, default=1, help="1") def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes): """Start a new development server.""" reloader = not no_reloader evalex = not no_evalex - run_simple(hostname, port, app, - use_reloader=reloader, use_debugger=debugger, - use_evalex=evalex, threaded=threaded, processes=processes) + run_simple( + hostname, + port, + app, + use_reloader=reloader, + use_debugger=debugger, + use_evalex=evalex, + threaded=threaded, + processes=processes, + ) @cli.command() -@click.option('--no-ipython', is_flag=True, default=False) +@click.option("--no-ipython", is_flag=True, default=False) def shell(no_ipython): """Start a new interactive python session.""" - banner = 'Interactive Werkzeug Shell' + banner = "Interactive Werkzeug Shell" namespace = dict() if not no_ipython: try: try: from IPython.frontend.terminal.embed import InteractiveShellEmbed + sh = InteractiveShellEmbed.instance(banner1=banner) except ImportError: from IPython.Shell import IPShellEmbed + sh = IPShellEmbed(banner=banner) except ImportError: pass @@ -63,7 +75,9 @@ def shell(no_ipython): sh(local_ns=namespace) return from code import interact + interact(banner, local=namespace) -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/examples/partial/complex_routing.py b/examples/partial/complex_routing.py index 18ce7b0a..596d00e4 100644 --- a/examples/partial/complex_routing.py +++ b/examples/partial/complex_routing.py @@ -1,19 +1,43 @@ -from werkzeug.routing import Map, Rule, Subdomain, Submount, EndpointPrefix +from werkzeug.routing import EndpointPrefix +from werkzeug.routing import Map +from werkzeug.routing import Rule +from werkzeug.routing import Subdomain +from werkzeug.routing import Submount -m = Map([ - # Static URLs - EndpointPrefix('static/', [ - Rule('/', endpoint='index'), - Rule('/about', endpoint='about'), - Rule('/help', endpoint='help'), - ]), - # Knowledge Base - Subdomain('kb', [EndpointPrefix('kb/', [ - Rule('/', endpoint='index'), - Submount('/browse', [ - Rule('/', endpoint='browse'), - Rule('/<int:id>/', defaults={'page': 1}, endpoint='browse'), - Rule('/<int:id>/<int:page>', endpoint='browse') - ]) - ])]) -]) +m = Map( + [ + # Static URLs + EndpointPrefix( + "static/", + [ + Rule("/", endpoint="index"), + Rule("/about", endpoint="about"), + Rule("/help", endpoint="help"), + ], + ), + # Knowledge Base + Subdomain( + "kb", + [ + EndpointPrefix( + "kb/", + [ + Rule("/", endpoint="index"), + Submount( + "/browse", + [ + Rule("/", endpoint="browse"), + Rule( + "/<int:id>/", + defaults={"page": 1}, + endpoint="browse", + ), + Rule("/<int:id>/<int:page>", endpoint="browse"), + ], + ), + ], + ) + ], + ), + ] +) diff --git a/examples/plnt/__init__.py b/examples/plnt/__init__.py index f1bb94f7..7ade4b72 100644 --- a/examples/plnt/__init__.py +++ b/examples/plnt/__init__.py @@ -8,4 +8,4 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from plnt.webapp import Plnt +from .webapp import Plnt diff --git a/examples/plnt/database.py b/examples/plnt/database.py index c74abb84..78e4ab6a 100644 --- a/examples/plnt/database.py +++ b/examples/plnt/database.py @@ -8,62 +8,73 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from sqlalchemy import MetaData, Table, Column, ForeignKey, \ - Integer, String, DateTime -from sqlalchemy.orm import dynamic_loader, scoped_session, create_session, \ - mapper -from plnt.utils import application, local_manager +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import create_session +from sqlalchemy.orm import dynamic_loader +from sqlalchemy.orm import mapper +from sqlalchemy.orm import scoped_session + +from .utils import application +from .utils import local_manager def new_db_session(): - return create_session(application.database_engine, autoflush=True, - autocommit=False) + return create_session(application.database_engine, autoflush=True, autocommit=False) + metadata = MetaData() session = scoped_session(new_db_session, local_manager.get_ident) -blog_table = Table('blogs', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(120)), - Column('description', String), - Column('url', String(200)), - Column('feed_url', String(250)) +blog_table = Table( + "blogs", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(120)), + Column("description", String), + Column("url", String(200)), + Column("feed_url", String(250)), ) -entry_table = Table('entries', metadata, - Column('id', Integer, primary_key=True), - Column('blog_id', Integer, ForeignKey('blogs.id')), - Column('guid', String(200), unique=True), - Column('title', String(140)), - Column('url', String(200)), - Column('text', String), - Column('pub_date', DateTime), - Column('last_update', DateTime) +entry_table = Table( + "entries", + metadata, + Column("id", Integer, primary_key=True), + Column("blog_id", Integer, ForeignKey("blogs.id")), + Column("guid", String(200), unique=True), + Column("title", String(140)), + Column("url", String(200)), + Column("text", String), + Column("pub_date", DateTime), + Column("last_update", DateTime), ) class Blog(object): query = session.query_property() - def __init__(self, name, url, feed_url, description=u''): + def __init__(self, name, url, feed_url, description=u""): self.name = name self.url = url self.feed_url = feed_url self.description = description def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, self.url) + return "<%s %r>" % (self.__class__.__name__, self.url) class Entry(object): query = session.query_property() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, self.guid) + return "<%s %r>" % (self.__class__.__name__, self.guid) mapper(Entry, entry_table) -mapper(Blog, blog_table, properties=dict( - entries=dynamic_loader(Entry, backref='blog') -)) +mapper(Blog, blog_table, properties=dict(entries=dynamic_loader(Entry, backref="blog"))) diff --git a/examples/plnt/sync.py b/examples/plnt/sync.py index b0265d85..e12ab844 100644 --- a/examples/plnt/sync.py +++ b/examples/plnt/sync.py @@ -8,14 +8,19 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import feedparser from datetime import datetime + +import feedparser from werkzeug.utils import escape -from plnt.database import Blog, Entry, session -from plnt.utils import strip_tags, nl2p + +from .database import Blog +from .database import Entry +from .database import session +from .utils import nl2p +from .utils import strip_tags -HTML_MIMETYPES = {'text/html', 'application/xhtml+xml'} +HTML_MIMETYPES = {"text/html", "application/xhtml+xml"} def sync(): @@ -31,7 +36,7 @@ def sync(): for entry in feed.entries: # get the guid. either the id if specified, otherwise the link. # if none is available we skip the entry. - guid = entry.get('id') or entry.get('link') + guid = entry.get("id") or entry.get("link") if not guid: continue @@ -41,17 +46,18 @@ def sync(): # get title, url and text. skip if no title or no text is # given. if the link is missing we use the blog link. - if 'title_detail' in entry: - title = entry.title_detail.get('value') or '' - if entry.title_detail.get('type') in HTML_MIMETYPES: + if "title_detail" in entry: + title = entry.title_detail.get("value") or "" + if entry.title_detail.get("type") in HTML_MIMETYPES: title = strip_tags(title) else: title = escape(title) else: - title = entry.get('title') - url = entry.get('link') or blog.blog_url - text = 'content' in entry and entry.content[0] or \ - entry.get('summary_detail') + title = entry.get("title") + url = entry.get("link") or blog.blog_url + text = ( + "content" in entry and entry.content[0] or entry.get("summary_detail") + ) if not title or not text: continue @@ -59,10 +65,10 @@ def sync(): # if we have an html text we use that, otherwise we HTML # escape the text and use that one. We also handle XHTML # with our tag soup parser for the moment. - if text.get('type') not in HTML_MIMETYPES: - text = escape(nl2p(text.get('value') or '')) + if text.get("type") not in HTML_MIMETYPES: + text = escape(nl2p(text.get("value") or "")) else: - text = text.get('value') or '' + text = text.get("value") or "" # no text? continue if not text.strip(): @@ -70,10 +76,12 @@ def sync(): # get the pub date and updated date. This is rather complex # because different feeds do different stuff - pub_date = entry.get('published_parsed') or \ - entry.get('created_parsed') or \ - entry.get('date_parsed') - updated = entry.get('updated_parsed') or pub_date + pub_date = ( + entry.get("published_parsed") + or entry.get("created_parsed") + or entry.get("date_parsed") + ) + updated = entry.get("updated_parsed") or pub_date pub_date = pub_date or updated # if we don't have a pub_date we skip. diff --git a/examples/plnt/utils.py b/examples/plnt/utils.py index 957c28da..5c6f0d0b 100644 --- a/examples/plnt/utils.py +++ b/examples/plnt/utils.py @@ -10,13 +10,16 @@ """ import re from os import path -from jinja2 import Environment, FileSystemLoader +from jinja2 import Environment +from jinja2 import FileSystemLoader from werkzeug._compat import unichr -from werkzeug.local import Local, LocalManager +from werkzeug.local import Local +from werkzeug.local import LocalManager +from werkzeug.routing import Map +from werkzeug.routing import Rule from werkzeug.utils import cached_property from werkzeug.wrappers import Response -from werkzeug.routing import Map, Rule # context locals. these two objects are use by the application to @@ -29,24 +32,24 @@ local_manager = LocalManager([local]) # proxy objects -request = local('request') -application = local('application') -url_adapter = local('url_adapter') +request = local("request") +application = local("application") +url_adapter = local("url_adapter") # let's use jinja for templates this time -template_path = path.join(path.dirname(__file__), 'templates') +template_path = path.join(path.dirname(__file__), "templates") jinja_env = Environment(loader=FileSystemLoader(template_path)) # the collected url patterns -url_map = Map([Rule('/shared/<path:file>', endpoint='shared')]) +url_map = Map([Rule("/shared/<path:file>", endpoint="shared")]) endpoints = {} -_par_re = re.compile(r'\n{2,}') -_entity_re = re.compile(r'&([^;]+);') -_striptags_re = re.compile(r'(<!--.*-->|<[^>]*>)') +_par_re = re.compile(r"\n{2,}") +_entity_re = re.compile(r"&([^;]+);") +_striptags_re = re.compile(r"(<!--.*-->|<[^>]*>)") try: from html.entities import name2codepoint @@ -54,30 +57,32 @@ except ImportError: from htmlentitydefs import name2codepoint html_entities = name2codepoint.copy() -html_entities['apos'] = 39 +html_entities["apos"] = 39 del name2codepoint def expose(url_rule, endpoint=None, **kwargs): """Expose this function to the web layer.""" + def decorate(f): e = endpoint or f.__name__ endpoints[e] = f url_map.add(Rule(url_rule, endpoint=e, **kwargs)) return f + return decorate def render_template(template_name, **context): """Render a template into a response.""" tmpl = jinja_env.get_template(template_name) - context['url_for'] = url_for - return Response(tmpl.render(context), mimetype='text/html') + context["url_for"] = url_for + return Response(tmpl.render(context), mimetype="text/html") def nl2p(s): """Add paragraphs to a text.""" - return u'\n'.join(u'<p>%s</p>' % p for p in _par_re.split(s)) + return u"\n".join(u"<p>%s</p>" % p for p in _par_re.split(s)) def url_for(endpoint, **kw): @@ -87,22 +92,24 @@ def url_for(endpoint, **kw): def strip_tags(s): """Resolve HTML entities and remove tags from a string.""" + def handle_match(m): name = m.group(1) if name in html_entities: return unichr(html_entities[name]) - if name[:2] in ('#x', '#X'): + if name[:2] in ("#x", "#X"): try: return unichr(int(name[2:], 16)) except ValueError: - return u'' - elif name.startswith('#'): + return u"" + elif name.startswith("#"): try: return unichr(int(name[1:])) except ValueError: - return u'' - return u'' - return _entity_re.sub(handle_match, _striptags_re.sub('', s)) + return u"" + return u"" + + return _entity_re.sub(handle_match, _striptags_re.sub("", s)) class Pagination(object): @@ -118,8 +125,11 @@ class Pagination(object): @cached_property def entries(self): - return self.query.offset((self.page - 1) * self.per_page) \ - .limit(self.per_page).all() + return ( + self.query.offset((self.page - 1) * self.per_page) + .limit(self.per_page) + .all() + ) @cached_property def count(self): diff --git a/examples/plnt/views.py b/examples/plnt/views.py index 7397531c..d64e98e3 100644 --- a/examples/plnt/views.py +++ b/examples/plnt/views.py @@ -9,32 +9,35 @@ :license: BSD-3-Clause """ from datetime import date -from plnt.database import Entry -from plnt.utils import Pagination, expose, render_template + +from .database import Entry +from .utils import expose +from .utils import Pagination +from .utils import render_template #: number of items per page PER_PAGE = 30 -@expose('/', defaults={'page': 1}) -@expose('/page/<int:page>') +@expose("/", defaults={"page": 1}) +@expose("/page/<int:page>") def index(request, page): """Show the index page or any an offset of it.""" days = [] days_found = set() query = Entry.query.order_by(Entry.pub_date.desc()) - pagination = Pagination(query, PER_PAGE, page, 'index') + pagination = Pagination(query, PER_PAGE, page, "index") for entry in pagination.entries: day = date(*entry.pub_date.timetuple()[:3]) if day not in days_found: days_found.add(day) - days.append({'date': day, 'entries': []}) - days[-1]['entries'].append(entry) - return render_template('index.html', days=days, pagination=pagination) + days.append({"date": day, "entries": []}) + days[-1]["entries"].append(entry) + return render_template("index.html", days=days, pagination=pagination) -@expose('/about') +@expose("/about") def about(request): """Show the about page, so that we have another view func ;-)""" - return render_template('about.html') + return render_template("about.html") diff --git a/examples/plnt/webapp.py b/examples/plnt/webapp.py index b571841a..dee7336b 100644 --- a/examples/plnt/webapp.py +++ b/examples/plnt/webapp.py @@ -9,31 +9,31 @@ :license: BSD-3-Clause """ from os import path -from sqlalchemy import create_engine +from sqlalchemy import create_engine +from werkzeug.exceptions import HTTPException from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.wrappers import Request from werkzeug.wsgi import ClosingIterator -from werkzeug.exceptions import HTTPException -from plnt.utils import local, local_manager, url_map, endpoints -from plnt.database import session, metadata -# import the views module because it contains setup code -import plnt.views +from . import views # noqa: F401 +from .database import metadata +from .database import session +from .utils import endpoints +from .utils import local +from .utils import local_manager +from .utils import url_map #: path to shared data -SHARED_DATA = path.join(path.dirname(__file__), 'shared') +SHARED_DATA = path.join(path.dirname(__file__), "shared") class Plnt(object): - def __init__(self, database_uri): self.database_engine = create_engine(database_uri) self._dispatch = local_manager.middleware(self.dispatch_request) - self._dispatch = SharedDataMiddleware(self._dispatch, { - '/shared': SHARED_DATA - }) + self._dispatch = SharedDataMiddleware(self._dispatch, {"/shared": SHARED_DATA}) def init_database(self): metadata.create_all(self.database_engine) @@ -50,8 +50,7 @@ class Plnt(object): response = endpoints[endpoint](request, **values) except HTTPException as e: response = e - return ClosingIterator(response(environ, start_response), - session.remove) + return ClosingIterator(response(environ, start_response), session.remove) def __call__(self, environ, start_response): return self._dispatch(environ, start_response) diff --git a/examples/shortly/shortly.py b/examples/shortly/shortly.py index 3446de0f..a3ff9f7f 100644 --- a/examples/shortly/shortly.py +++ b/examples/shortly/shortly.py @@ -9,32 +9,35 @@ :license: BSD-3-Clause """ import os -import redis +import redis +from jinja2 import Environment +from jinja2 import FileSystemLoader +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware +from werkzeug.routing import Map +from werkzeug.routing import Rule from werkzeug.urls import url_parse -from werkzeug.wrappers import Request, Response -from werkzeug.routing import Map, Rule -from werkzeug.exceptions import HTTPException, NotFound from werkzeug.utils import redirect - -from jinja2 import Environment, FileSystemLoader +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def base36_encode(number): - assert number >= 0, 'positive integer required' + assert number >= 0, "positive integer required" if number == 0: - return '0' + return "0" base36 = [] while number != 0: number, i = divmod(number, 36) - base36.append('0123456789abcdefghijklmnopqrstuvwxyz'[i]) - return ''.join(reversed(base36)) + base36.append("0123456789abcdefghijklmnopqrstuvwxyz"[i]) + return "".join(reversed(base36)) def is_valid_url(url): parts = url_parse(url) - return parts.scheme in ('http', 'https') + return parts.scheme in ("http", "https") def get_hostname(url): @@ -42,74 +45,77 @@ def get_hostname(url): class Shortly(object): - def __init__(self, config): - self.redis = redis.Redis(config['redis_host'], config['redis_port']) - template_path = os.path.join(os.path.dirname(__file__), 'templates') - self.jinja_env = Environment(loader=FileSystemLoader(template_path), - autoescape=True) - self.jinja_env.filters['hostname'] = get_hostname - - self.url_map = Map([ - Rule('/', endpoint='new_url'), - Rule('/<short_id>', endpoint='follow_short_link'), - Rule('/<short_id>+', endpoint='short_link_details') - ]) + self.redis = redis.Redis(config["redis_host"], config["redis_port"]) + template_path = os.path.join(os.path.dirname(__file__), "templates") + self.jinja_env = Environment( + loader=FileSystemLoader(template_path), autoescape=True + ) + self.jinja_env.filters["hostname"] = get_hostname + + self.url_map = Map( + [ + Rule("/", endpoint="new_url"), + Rule("/<short_id>", endpoint="follow_short_link"), + Rule("/<short_id>+", endpoint="short_link_details"), + ] + ) def on_new_url(self, request): error = None - url = '' - if request.method == 'POST': - url = request.form['url'] + url = "" + if request.method == "POST": + url = request.form["url"] if not is_valid_url(url): - error = 'Please enter a valid URL' + error = "Please enter a valid URL" else: short_id = self.insert_url(url) - return redirect('/%s+' % short_id) - return self.render_template('new_url.html', error=error, url=url) + return redirect("/%s+" % short_id) + return self.render_template("new_url.html", error=error, url=url) def on_follow_short_link(self, request, short_id): - link_target = self.redis.get('url-target:' + short_id) + link_target = self.redis.get("url-target:" + short_id) if link_target is None: raise NotFound() - self.redis.incr('click-count:' + short_id) + self.redis.incr("click-count:" + short_id) return redirect(link_target) def on_short_link_details(self, request, short_id): - link_target = self.redis.get('url-target:' + short_id) + link_target = self.redis.get("url-target:" + short_id) if link_target is None: raise NotFound() - click_count = int(self.redis.get('click-count:' + short_id) or 0) - return self.render_template('short_link_details.html', + click_count = int(self.redis.get("click-count:" + short_id) or 0) + return self.render_template( + "short_link_details.html", link_target=link_target, short_id=short_id, - click_count=click_count + click_count=click_count, ) def error_404(self): - response = self.render_template('404.html') + response = self.render_template("404.html") response.status_code = 404 return response def insert_url(self, url): - short_id = self.redis.get('reverse-url:' + url) + short_id = self.redis.get("reverse-url:" + url) if short_id is not None: return short_id - url_num = self.redis.incr('last-url-id') + url_num = self.redis.incr("last-url-id") short_id = base36_encode(url_num) - self.redis.set('url-target:' + short_id, url) - self.redis.set('reverse-url:' + url, short_id) + self.redis.set("url-target:" + short_id, url) + self.redis.set("reverse-url:" + url, short_id) return short_id def render_template(self, template_name, **context): t = self.jinja_env.get_template(template_name) - return Response(t.render(context), mimetype='text/html') + return Response(t.render(context), mimetype="text/html") def dispatch_request(self, request): adapter = self.url_map.bind_to_environ(request.environ) try: endpoint, values = adapter.match() - return getattr(self, 'on_' + endpoint)(request, **values) + return getattr(self, "on_" + endpoint)(request, **values) except NotFound: return self.error_404() except HTTPException as e: @@ -124,19 +130,17 @@ class Shortly(object): return self.wsgi_app(environ, start_response) -def create_app(redis_host='localhost', redis_port=6379, with_static=True): - app = Shortly({ - 'redis_host': redis_host, - 'redis_port': redis_port - }) +def create_app(redis_host="localhost", redis_port=6379, with_static=True): + app = Shortly({"redis_host": redis_host, "redis_port": redis_port}) if with_static: - app.wsgi_app = SharedDataMiddleware(app.wsgi_app, { - '/static': os.path.join(os.path.dirname(__file__), 'static') - }) + app.wsgi_app = SharedDataMiddleware( + app.wsgi_app, {"/static": os.path.join(os.path.dirname(__file__), "static")} + ) return app -if __name__ == '__main__': +if __name__ == "__main__": from werkzeug.serving import run_simple + app = create_app() - run_simple('127.0.0.1', 5000, app, use_debugger=True, use_reloader=True) + run_simple("127.0.0.1", 5000, app, use_debugger=True, use_reloader=True) diff --git a/examples/shorty/application.py b/examples/shorty/application.py index cfec075f..af793444 100644 --- a/examples/shorty/application.py +++ b/examples/shorty/application.py @@ -1,24 +1,25 @@ from sqlalchemy import create_engine - +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.wrappers import Request from werkzeug.wsgi import ClosingIterator -from werkzeug.exceptions import HTTPException, NotFound -from shorty.utils import STATIC_PATH, session, local, local_manager, \ - metadata, url_map -from shorty import views +from . import views +from .utils import local +from .utils import local_manager +from .utils import metadata +from .utils import session +from .utils import STATIC_PATH +from .utils import url_map class Shorty(object): - def __init__(self, db_uri): local.application = self self.database_engine = create_engine(db_uri, convert_unicode=True) - self.dispatch = SharedDataMiddleware(self.dispatch, { - '/static': STATIC_PATH - }) + self.dispatch = SharedDataMiddleware(self.dispatch, {"/static": STATIC_PATH}) def init_database(self): metadata.create_all(self.database_engine) @@ -36,8 +37,9 @@ class Shorty(object): response.status_code = 404 except HTTPException as e: response = e - return ClosingIterator(response(environ, start_response), - [session.remove, local_manager.cleanup]) + return ClosingIterator( + response(environ, start_response), [session.remove, local_manager.cleanup] + ) def __call__(self, environ, start_response): return self.dispatch(environ, start_response) diff --git a/examples/shorty/models.py b/examples/shorty/models.py index 78b5f1df..7d0df5bd 100644 --- a/examples/shorty/models.py +++ b/examples/shorty/models.py @@ -1,15 +1,27 @@ from datetime import datetime -from sqlalchemy import Table, Column, String, Boolean, DateTime + +from sqlalchemy import Boolean +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy.orm import mapper -from shorty.utils import session, metadata, url_for, get_random_uid -url_table = Table('urls', metadata, - Column('uid', String(140), primary_key=True), - Column('target', String(500)), - Column('added', DateTime), - Column('public', Boolean) +from .utils import get_random_uid +from .utils import metadata +from .utils import session +from .utils import url_for + +url_table = Table( + "urls", + metadata, + Column("uid", String(140), primary_key=True), + Column("target", String(500)), + Column("added", DateTime), + Column("public", Boolean), ) + class URL(object): query = session.query_property() @@ -27,9 +39,10 @@ class URL(object): @property def short_url(self): - return url_for('link', uid=self.uid, _external=True) + return url_for("link", uid=self.uid, _external=True) def __repr__(self): - return '<URL %r>' % self.uid + return "<URL %r>" % self.uid + mapper(URL, url_table) diff --git a/examples/shorty/utils.py b/examples/shorty/utils.py index bbcb0f20..f61f9ef9 100644 --- a/examples/shorty/utils.py +++ b/examples/shorty/utils.py @@ -1,57 +1,72 @@ from os import path -from random import sample, randrange -from jinja2 import Environment, FileSystemLoader -from werkzeug.local import Local, LocalManager +from random import randrange +from random import sample + +from jinja2 import Environment +from jinja2 import FileSystemLoader +from sqlalchemy import MetaData +from sqlalchemy.orm import create_session +from sqlalchemy.orm import scoped_session +from werkzeug.local import Local +from werkzeug.local import LocalManager +from werkzeug.routing import Map +from werkzeug.routing import Rule from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response -from werkzeug.routing import Map, Rule -from sqlalchemy import MetaData -from sqlalchemy.orm import create_session, scoped_session -TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates') -STATIC_PATH = path.join(path.dirname(__file__), 'static') -ALLOWED_SCHEMES = frozenset(['http', 'https', 'ftp', 'ftps']) -URL_CHARS = 'abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789' +TEMPLATE_PATH = path.join(path.dirname(__file__), "templates") +STATIC_PATH = path.join(path.dirname(__file__), "static") +ALLOWED_SCHEMES = frozenset(["http", "https", "ftp", "ftps"]) +URL_CHARS = "abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789" local = Local() local_manager = LocalManager([local]) -application = local('application') +application = local("application") metadata = MetaData() -url_map = Map([Rule('/static/<file>', endpoint='static', build_only=True)]) +url_map = Map([Rule("/static/<file>", endpoint="static", build_only=True)]) -session = scoped_session(lambda: create_session(application.database_engine, - autocommit=False, - autoflush=False)) +session = scoped_session( + lambda: create_session( + application.database_engine, autocommit=False, autoflush=False + ) +) jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH)) def expose(rule, **kw): def decorate(f): - kw['endpoint'] = f.__name__ + kw["endpoint"] = f.__name__ url_map.add(Rule(rule, **kw)) return f + return decorate + def url_for(endpoint, _external=False, **values): return local.url_adapter.build(endpoint, values, force_external=_external) -jinja_env.globals['url_for'] = url_for + + +jinja_env.globals["url_for"] = url_for + def render_template(template, **context): - return Response(jinja_env.get_template(template).render(**context), - mimetype='text/html') + return Response( + jinja_env.get_template(template).render(**context), mimetype="text/html" + ) + def validate_url(url): return url_parse(url)[0] in ALLOWED_SCHEMES + def get_random_uid(): - return ''.join(sample(URL_CHARS, randrange(3, 9))) + return "".join(sample(URL_CHARS, randrange(3, 9))) class Pagination(object): - def __init__(self, query, per_page, page, endpoint): self.query = query self.per_page = per_page @@ -64,8 +79,11 @@ class Pagination(object): @cached_property def entries(self): - return self.query.offset((self.page - 1) * self.per_page) \ - .limit(self.per_page).all() + return ( + self.query.offset((self.page - 1) * self.per_page) + .limit(self.per_page) + .all() + ) has_previous = property(lambda self: self.page > 1) has_next = property(lambda self: self.page < self.pages) diff --git a/examples/shorty/views.py b/examples/shorty/views.py index fa4269a3..7a1ee20b 100644 --- a/examples/shorty/views.py +++ b/examples/shorty/views.py @@ -1,52 +1,62 @@ -from werkzeug.utils import redirect from werkzeug.exceptions import NotFound -from shorty.utils import session, Pagination, render_template, expose, \ - validate_url, url_for -from shorty.models import URL +from werkzeug.utils import redirect + +from .models import URL +from .utils import expose +from .utils import Pagination +from .utils import render_template +from .utils import session +from .utils import url_for +from .utils import validate_url + -@expose('/') +@expose("/") def new(request): - error = url = '' - if request.method == 'POST': - url = request.form.get('url') - alias = request.form.get('alias') + error = url = "" + if request.method == "POST": + url = request.form.get("url") + alias = request.form.get("alias") if not validate_url(url): error = "I'm sorry but you cannot shorten this URL." elif alias: if len(alias) > 140: - error = 'Your alias is too long' - elif '/' in alias: - error = 'Your alias might not include a slash' + error = "Your alias is too long" + elif "/" in alias: + error = "Your alias might not include a slash" elif URL.query.get(alias): - error = 'The alias you have requested exists already' + error = "The alias you have requested exists already" if not error: - uid = URL(url, 'private' not in request.form, alias).uid + uid = URL(url, "private" not in request.form, alias).uid session.commit() - return redirect(url_for('display', uid=uid)) - return render_template('new.html', error=error, url=url) + return redirect(url_for("display", uid=uid)) + return render_template("new.html", error=error, url=url) -@expose('/display/<uid>') + +@expose("/display/<uid>") def display(request, uid): url = URL.query.get(uid) if not url: raise NotFound() - return render_template('display.html', url=url) + return render_template("display.html", url=url) + -@expose('/u/<uid>') +@expose("/u/<uid>") def link(request, uid): url = URL.query.get(uid) if not url: raise NotFound() return redirect(url.target, 301) -@expose('/list/', defaults={'page': 1}) -@expose('/list/<int:page>') + +@expose("/list/", defaults={"page": 1}) +@expose("/list/<int:page>") def list(request, page): query = URL.query.filter_by(public=True) - pagination = Pagination(query, 30, page, 'list') + pagination = Pagination(query, 30, page, "list") if pagination.page > 1 and not pagination.entries: raise NotFound() - return render_template('list.html', pagination=pagination) + return render_template("list.html", pagination=pagination) + def not_found(request): - return render_template('not_found.html') + return render_template("not_found.html") diff --git a/examples/simplewiki/__init__.py b/examples/simplewiki/__init__.py index c74c2cf6..591a06e9 100644 --- a/examples/simplewiki/__init__.py +++ b/examples/simplewiki/__init__.py @@ -9,4 +9,4 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from simplewiki.application import SimpleWiki +from .application import SimpleWiki diff --git a/examples/simplewiki/actions.py b/examples/simplewiki/actions.py index 29b2ff5b..c9a88b41 100644 --- a/examples/simplewiki/actions.py +++ b/examples/simplewiki/actions.py @@ -13,15 +13,22 @@ :license: BSD-3-Clause """ from difflib import unified_diff -from simplewiki.utils import Response, generate_template, \ - href, format_datetime -from simplewiki.database import RevisionedPage, Page, Revision, session + from werkzeug.utils import redirect +from .database import Page +from .database import Revision +from .database import RevisionedPage +from .database import session +from .utils import format_datetime +from .utils import generate_template +from .utils import href +from .utils import Response + def on_show(request, page_name): """Displays the page the user requests.""" - revision_id = request.args.get('rev', type=int) + revision_id = request.args.get("rev", type=int) query = RevisionedPage.query.filter_by(name=page_name) if revision_id: query = query.filter_by(revision_id=revision_id) @@ -32,32 +39,32 @@ def on_show(request, page_name): page = query.first() if page is None: return page_missing(request, page_name, revision_requested) - return Response(generate_template('action_show.html', - page=page - )) + return Response(generate_template("action_show.html", page=page)) def on_edit(request, page_name): """Edit the current revision of a page.""" - change_note = error = '' - revision = Revision.query.filter( - (Page.name == page_name) & - (Page.page_id == Revision.page_id) - ).order_by(Revision.revision_id.desc()).first() + change_note = error = "" + revision = ( + Revision.query.filter( + (Page.name == page_name) & (Page.page_id == Revision.page_id) + ) + .order_by(Revision.revision_id.desc()) + .first() + ) if revision is None: page = None else: page = revision.page - if request.method == 'POST': - text = request.form.get('text') - if request.form.get('cancel') or \ - revision and revision.text == text: + if request.method == "POST": + text = request.form.get("text") + if request.form.get("cancel") or revision and revision.text == text: return redirect(href(page.name)) elif not text: - error = 'You cannot save empty revisions.' + error = "You cannot save empty revisions." else: - change_note = request.form.get('change_note', '') + change_note = request.form.get("change_note", "") if page is None: page = Page(page_name) session.add(page) @@ -65,14 +72,17 @@ def on_edit(request, page_name): session.commit() return redirect(href(page.name)) - return Response(generate_template('action_edit.html', - revision=revision, - page=page, - new=page is None, - page_name=page_name, - change_note=change_note, - error=error - )) + return Response( + generate_template( + "action_edit.html", + revision=revision, + page=page, + new=page is None, + page_name=page_name, + change_note=change_note, + error=error, + ) + ) def on_log(request, page_name): @@ -80,109 +90,117 @@ def on_log(request, page_name): page = Page.query.filter_by(name=page_name).first() if page is None: return page_missing(request, page_name, False) - return Response(generate_template('action_log.html', - page=page - )) + return Response(generate_template("action_log.html", page=page)) def on_diff(request, page_name): """Show the diff between two revisions.""" - old = request.args.get('old', type=int) - new = request.args.get('new', type=int) - error = '' + old = request.args.get("old", type=int) + new = request.args.get("new", type=int) + error = "" diff = page = old_rev = new_rev = None if not (old and new): - error = 'No revisions specified.' + error = "No revisions specified." else: - revisions = dict((x.revision_id, x) for x in Revision.query.filter( - (Revision.revision_id.in_((old, new))) & - (Revision.page_id == Page.page_id) & - (Page.name == page_name) - )) + revisions = dict( + (x.revision_id, x) + for x in Revision.query.filter( + (Revision.revision_id.in_((old, new))) + & (Revision.page_id == Page.page_id) + & (Page.name == page_name) + ) + ) if len(revisions) != 2: - error = 'At least one of the revisions requested ' \ - 'does not exist.' + error = "At least one of the revisions requested does not exist." else: new_rev = revisions[new] old_rev = revisions[old] page = old_rev.page diff = unified_diff( - (old_rev.text + '\n').splitlines(True), - (new_rev.text + '\n').splitlines(True), - page.name, page.name, + (old_rev.text + "\n").splitlines(True), + (new_rev.text + "\n").splitlines(True), + page.name, + page.name, format_datetime(old_rev.timestamp), format_datetime(new_rev.timestamp), - 3 + 3, ) - return Response(generate_template('action_diff.html', - error=error, - old_revision=old_rev, - new_revision=new_rev, - page=page, - diff=diff - )) + return Response( + generate_template( + "action_diff.html", + error=error, + old_revision=old_rev, + new_revision=new_rev, + page=page, + diff=diff, + ) + ) def on_revert(request, page_name): """Revert an old revision.""" - rev_id = request.args.get('rev', type=int) + rev_id = request.args.get("rev", type=int) old_revision = page = None - error = 'No such revision' + error = "No such revision" - if request.method == 'POST' and request.form.get('cancel'): + if request.method == "POST" and request.form.get("cancel"): return redirect(href(page_name)) if rev_id: old_revision = Revision.query.filter( - (Revision.revision_id == rev_id) & - (Revision.page_id == Page.page_id) & - (Page.name == page_name) + (Revision.revision_id == rev_id) + & (Revision.page_id == Page.page_id) + & (Page.name == page_name) ).first() if old_revision: - new_revision = Revision.query.filter( - (Revision.page_id == Page.page_id) & - (Page.name == page_name) - ).order_by(Revision.revision_id.desc()).first() + new_revision = ( + Revision.query.filter( + (Revision.page_id == Page.page_id) & (Page.name == page_name) + ) + .order_by(Revision.revision_id.desc()) + .first() + ) if old_revision == new_revision: - error = 'You tried to revert the current active ' \ - 'revision.' + error = "You tried to revert the current active revision." elif old_revision.text == new_revision.text: - error = 'There are no changes between the current ' \ - 'revision and the revision you want to ' \ - 'restore.' + error = ( + "There are no changes between the current " + "revision and the revision you want to " + "restore." + ) else: - error = '' + error = "" page = old_revision.page - if request.method == 'POST': - change_note = request.form.get('change_note', '') - change_note = 'revert' + (change_note and ': ' + - change_note or '') - session.add(Revision(page, old_revision.text, - change_note)) + if request.method == "POST": + change_note = request.form.get("change_note", "") + change_note = "revert" + (change_note and ": " + change_note or "") + session.add(Revision(page, old_revision.text, change_note)) session.commit() return redirect(href(page_name)) - return Response(generate_template('action_revert.html', - error=error, - old_revision=old_revision, - page=page - )) + return Response( + generate_template( + "action_revert.html", error=error, old_revision=old_revision, page=page + ) + ) def page_missing(request, page_name, revision_requested, protected=False): """Displayed if page or revision does not exist.""" - return Response(generate_template('page_missing.html', - page_name=page_name, - revision_requested=revision_requested, - protected=protected - ), status=404) + return Response( + generate_template( + "page_missing.html", + page_name=page_name, + revision_requested=revision_requested, + protected=protected, + ), + status=404, + ) def missing_action(request, action): """Displayed if a user tried to access a action that does not exist.""" - return Response(generate_template('missing_action.html', - action=action - ), status=404) + return Response(generate_template("missing_action.html", action=action), status=404) diff --git a/examples/simplewiki/application.py b/examples/simplewiki/application.py index b994b17d..5f85eef3 100644 --- a/examples/simplewiki/application.py +++ b/examples/simplewiki/application.py @@ -11,19 +11,25 @@ :license: BSD-3-Clause """ from os import path -from sqlalchemy import create_engine +from sqlalchemy import create_engine from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.utils import redirect from werkzeug.wsgi import ClosingIterator -from simplewiki.utils import Request, local, local_manager, href -from simplewiki.database import session, metadata -from simplewiki import actions -from simplewiki.specialpages import pages, page_not_found + +from . import actions +from .database import metadata +from .database import session +from .specialpages import page_not_found +from .specialpages import pages +from .utils import href +from .utils import local +from .utils import local_manager +from .utils import Request #: path to shared data -SHARED_DATA = path.join(path.dirname(__file__), 'shared') +SHARED_DATA = path.join(path.dirname(__file__), "shared") class SimpleWiki(object): @@ -37,9 +43,9 @@ class SimpleWiki(object): # apply our middlewares. we apply the middlewars *inside* the # application and not outside of it so that we never lose the # reference to the `SimpleWiki` object. - self._dispatch = SharedDataMiddleware(self.dispatch_request, { - '/_shared': SHARED_DATA - }) + self._dispatch = SharedDataMiddleware( + self.dispatch_request, {"/_shared": SHARED_DATA} + ) # free the context locals at the end of the request self._dispatch = local_manager.make_middleware(self._dispatch) @@ -66,16 +72,15 @@ class SimpleWiki(object): # get the current action from the url and normalize the page name # which is just the request path - action_name = request.args.get('action') or 'show' - page_name = u'_'.join([x for x in request.path.strip('/') - .split() if x]) + action_name = request.args.get("action") or "show" + page_name = u"_".join([x for x in request.path.strip("/").split() if x]) # redirect to the Main_Page if the user requested the index if not page_name: - response = redirect(href('Main_Page')) + response = redirect(href("Main_Page")) # check special pages - elif page_name.startswith('Special:'): + elif page_name.startswith("Special:"): if page_name[8:] not in pages: response = page_not_found(request, page_name) else: @@ -85,15 +90,14 @@ class SimpleWiki(object): # action module. It's "on_" + the action name. If it doesn't # exists call the missing_action method from the same module. else: - action = getattr(actions, 'on_' + action_name, None) + action = getattr(actions, "on_" + action_name, None) if action is None: response = actions.missing_action(request, action_name) else: response = action(request, page_name) # make sure the session is removed properly - return ClosingIterator(response(environ, start_response), - session.remove) + return ClosingIterator(response(environ, start_response), session.remove) def __call__(self, environ, start_response): """Just forward a WSGI call to the first internal middleware.""" diff --git a/examples/simplewiki/database.py b/examples/simplewiki/database.py index d5df7387..f0cec34e 100644 --- a/examples/simplewiki/database.py +++ b/examples/simplewiki/database.py @@ -9,11 +9,23 @@ :license: BSD-3-Clause """ from datetime import datetime -from sqlalchemy import Table, Column, Integer, String, DateTime, \ - ForeignKey, MetaData, join -from sqlalchemy.orm import relation, create_session, scoped_session, \ - mapper -from simplewiki.utils import application, local_manager, parse_creole + +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import join +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import create_session +from sqlalchemy.orm import mapper +from sqlalchemy.orm import relation +from sqlalchemy.orm import scoped_session + +from .utils import application +from .utils import local_manager +from .utils import parse_creole # create a global metadata @@ -28,8 +40,7 @@ def new_db_session(): application. If there is no application bound to the context it raises an exception. """ - return create_session(application.database_engine, autoflush=True, - autocommit=False) + return create_session(application.database_engine, autoflush=True, autocommit=False) # and create a new global session factory. Calling this object gives @@ -38,17 +49,21 @@ session = scoped_session(new_db_session, local_manager.get_ident) # our database tables. -page_table = Table('pages', metadata, - Column('page_id', Integer, primary_key=True), - Column('name', String(60), unique=True) +page_table = Table( + "pages", + metadata, + Column("page_id", Integer, primary_key=True), + Column("name", String(60), unique=True), ) -revision_table = Table('revisions', metadata, - Column('revision_id', Integer, primary_key=True), - Column('page_id', Integer, ForeignKey('pages.page_id')), - Column('timestamp', DateTime), - Column('text', String), - Column('change_note', String(200)) +revision_table = Table( + "revisions", + metadata, + Column("revision_id", Integer, primary_key=True), + Column("page_id", Integer, ForeignKey("pages.page_id")), + Column("timestamp", DateTime), + Column("text", String), + Column("change_note", String(200)), ) @@ -59,9 +74,10 @@ class Revision(object): new revisions. It's also used for the diff system and the revision log. """ + query = session.query_property() - def __init__(self, page, text, change_note='', timestamp=None): + def __init__(self, page, text, change_note="", timestamp=None): if isinstance(page, int): self.page_id = page else: @@ -75,11 +91,7 @@ class Revision(object): return parse_creole(self.text) def __repr__(self): - return '<%s %r:%r>' % ( - self.__class__.__name__, - self.page_id, - self.revision_id - ) + return "<%s %r:%r>" % (self.__class__.__name__, self.page_id, self.revision_id) class Page(object): @@ -87,6 +99,7 @@ class Page(object): Represents a simple page without any revisions. This is for example used in the page index where the page contents are not relevant. """ + query = session.query_property() def __init__(self, name): @@ -94,10 +107,10 @@ class Page(object): @property def title(self): - return self.name.replace('_', ' ') + return self.name.replace("_", " ") def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, self.name) + return "<%s %r>" % (self.__class__.__name__, self.name) class RevisionedPage(Page, Revision): @@ -106,26 +119,32 @@ class RevisionedPage(Page, Revision): and the ability of SQLAlchemy to map to joins we can combine `Page` and `Revision` into one class here. """ + query = session.query_property() def __init__(self): - raise TypeError('cannot create WikiPage instances, use the Page and ' - 'Revision classes for data manipulation.') + raise TypeError( + "cannot create WikiPage instances, use the Page and " + "Revision classes for data manipulation." + ) def __repr__(self): - return '<%s %r:%r>' % ( - self.__class__.__name__, - self.name, - self.revision_id - ) + return "<%s %r:%r>" % (self.__class__.__name__, self.name, self.revision_id) # setup mappers mapper(Revision, revision_table) -mapper(Page, page_table, properties=dict( - revisions=relation(Revision, backref='page', - order_by=Revision.revision_id.desc()) -)) -mapper(RevisionedPage, join(page_table, revision_table), properties=dict( - page_id=[page_table.c.page_id, revision_table.c.page_id], -)) +mapper( + Page, + page_table, + properties=dict( + revisions=relation( + Revision, backref="page", order_by=Revision.revision_id.desc() + ) + ), +) +mapper( + RevisionedPage, + join(page_table, revision_table), + properties=dict(page_id=[page_table.c.page_id, revision_table.c.page_id]), +) diff --git a/examples/simplewiki/specialpages.py b/examples/simplewiki/specialpages.py index 9717d6c1..9636f7ca 100644 --- a/examples/simplewiki/specialpages.py +++ b/examples/simplewiki/specialpages.py @@ -9,10 +9,12 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from simplewiki.utils import Response, Pagination, generate_template -from simplewiki.database import RevisionedPage, Page -from simplewiki.actions import page_missing - +from .actions import page_missing +from .database import Page +from .database import RevisionedPage +from .utils import generate_template +from .utils import Pagination +from .utils import Response def page_index(request): @@ -20,19 +22,21 @@ def page_index(request): letters = {} for page in Page.query.order_by(Page.name): letters.setdefault(page.name.capitalize()[0], []).append(page) - return Response(generate_template('page_index.html', - letters=sorted(letters.items()) - )) + return Response( + generate_template("page_index.html", letters=sorted(letters.items())) + ) def recent_changes(request): """Display the recent changes.""" - page = max(1, request.args.get('page', type=int)) - query = RevisionedPage.query \ - .order_by(RevisionedPage.revision_id.desc()) - return Response(generate_template('recent_changes.html', - pagination=Pagination(query, 20, page, 'Special:Recent_Changes') - )) + page = max(1, request.args.get("page", type=int)) + query = RevisionedPage.query.order_by(RevisionedPage.revision_id.desc()) + return Response( + generate_template( + "recent_changes.html", + pagination=Pagination(query, 20, page, "Special:Recent_Changes"), + ) + ) def page_not_found(request, page_name): @@ -43,7 +47,4 @@ def page_not_found(request, page_name): return page_missing(request, page_name, True) -pages = { - 'Index': page_index, - 'Recent_Changes': recent_changes -} +pages = {"Index": page_index, "Recent_Changes": recent_changes} diff --git a/examples/simplewiki/templates/macros.xml b/examples/simplewiki/templates/macros.xml index 14bebd83..28ce06b9 100644 --- a/examples/simplewiki/templates/macros.xml +++ b/examples/simplewiki/templates/macros.xml @@ -1,6 +1,6 @@ <div xmlns="http://www.w3.org/1999/xhtml" xmlns:py="http://genshi.edgewall.org/" py:strip=""> - + <py:def function="render_pagination(pagination)"> <div class="pagination" py:if="pagination.pages > 1"> <py:choose test="pagination.has_previous"> diff --git a/examples/simplewiki/utils.py b/examples/simplewiki/utils.py index f6f11d29..66289e8b 100644 --- a/examples/simplewiki/utils.py +++ b/examples/simplewiki/utils.py @@ -9,20 +9,25 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import creoleparser from os import path + +import creoleparser from genshi import Stream from genshi.template import TemplateLoader -from werkzeug.local import Local, LocalManager -from werkzeug.urls import url_encode, url_quote +from werkzeug.local import Local +from werkzeug.local import LocalManager +from werkzeug.urls import url_encode +from werkzeug.urls import url_quote from werkzeug.utils import cached_property -from werkzeug.wrappers import BaseRequest, BaseResponse +from werkzeug.wrappers import BaseRequest +from werkzeug.wrappers import BaseResponse # calculate the path to the templates an create the template loader -TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates') -template_loader = TemplateLoader(TEMPLATE_PATH, auto_reload=True, - variable_lookup='lenient') +TEMPLATE_PATH = path.join(path.dirname(__file__), "templates") +template_loader = TemplateLoader( + TEMPLATE_PATH, auto_reload=True, variable_lookup="lenient" +) # context locals. these two objects are use by the application to @@ -30,27 +35,25 @@ template_loader = TemplateLoader(TEMPLATE_PATH, auto_reload=True, # current thread and the current greenlet if there is greenlet support. local = Local() local_manager = LocalManager([local]) -request = local('request') -application = local('application') +request = local("request") +application = local("application") # create a new creole parser creole_parser = creoleparser.Parser( - dialect=creoleparser.create_dialect(creoleparser.creole10_base, - wiki_links_base_url='', + dialect=creoleparser.create_dialect( + creoleparser.creole10_base, + wiki_links_base_url="", wiki_links_path_func=lambda page_name: href(page_name), - wiki_links_space_char='_', - no_wiki_monospace=True + wiki_links_space_char="_", + no_wiki_monospace=True, ), - method='html' + method="html", ) def generate_template(template_name, **context): """Load and generate a template.""" - context.update( - href=href, - format_datetime=format_datetime - ) + context.update(href=href, format_datetime=format_datetime) return template_loader.load(template_name).generate(**context) @@ -64,17 +67,17 @@ def href(*args, **kw): Simple function for URL generation. Position arguments are used for the URL path and keyword arguments are used for the url parameters. """ - result = [(request.script_root if request else '') + '/'] + result = [(request.script_root if request else "") + "/"] for idx, arg in enumerate(args): - result.append(('/' if idx else '') + url_quote(arg)) + result.append(("/" if idx else "") + url_quote(arg)) if kw: - result.append('?' + url_encode(kw)) - return ''.join(result) + result.append("?" + url_encode(kw)) + return "".join(result) def format_datetime(obj): """Format a datetime object.""" - return obj.strftime('%Y-%m-%d %H:%M') + return obj.strftime("%Y-%m-%d %H:%M") class Request(BaseRequest): @@ -94,14 +97,14 @@ class Response(BaseResponse): to html. This makes it possible to switch to xhtml or html5 easily. """ - default_mimetype = 'text/html' + default_mimetype = "text/html" - def __init__(self, response=None, status=200, headers=None, mimetype=None, - content_type=None): + def __init__( + self, response=None, status=200, headers=None, mimetype=None, content_type=None + ): if isinstance(response, Stream): - response = response.render('html', encoding=None, doctype='html') - BaseResponse.__init__(self, response, status, headers, mimetype, - content_type) + response = response.render("html", encoding=None, doctype="html") + BaseResponse.__init__(self, response, status, headers, mimetype, content_type) class Pagination(object): @@ -118,8 +121,11 @@ class Pagination(object): @cached_property def entries(self): - return self.query.offset((self.page - 1) * self.per_page) \ - .limit(self.per_page).all() + return ( + self.query.offset((self.page - 1) * self.per_page) + .limit(self.per_page) + .all() + ) @property def has_previous(self): diff --git a/examples/upload.py b/examples/upload.py index 88af63aa..6f520b42 100644 --- a/examples/upload.py +++ b/examples/upload.py @@ -9,36 +9,39 @@ :license: BSD-3-Clause """ from werkzeug.serving import run_simple -from werkzeug.wrappers import BaseRequest, BaseResponse +from werkzeug.wrappers import BaseRequest +from werkzeug.wrappers import BaseResponse from werkzeug.wsgi import wrap_file def view_file(req): - if not 'uploaded_file' in req.files: - return BaseResponse('no file uploaded') - f = req.files['uploaded_file'] - return BaseResponse(wrap_file(req.environ, f), mimetype=f.content_type, - direct_passthrough=True) + if "uploaded_file" not in req.files: + return BaseResponse("no file uploaded") + f = req.files["uploaded_file"] + return BaseResponse( + wrap_file(req.environ, f), mimetype=f.content_type, direct_passthrough=True + ) def upload_file(req): - return BaseResponse(''' - <h1>Upload File</h1> - <form action="" method="post" enctype="multipart/form-data"> - <input type="file" name="uploaded_file"> - <input type="submit" value="Upload"> - </form> - ''', mimetype='text/html') + return BaseResponse( + """<h1>Upload File</h1> + <form action="" method="post" enctype="multipart/form-data"> + <input type="file" name="uploaded_file"> + <input type="submit" value="Upload"> + </form>""", + mimetype="text/html", + ) def application(environ, start_response): req = BaseRequest(environ) - if req.method == 'POST': + if req.method == "POST": resp = view_file(req) else: resp = upload_file(req) return resp(environ, start_response) -if __name__ == '__main__': - run_simple('localhost', 5000, application, use_debugger=True) +if __name__ == "__main__": + run_simple("localhost", 5000, application, use_debugger=True) diff --git a/examples/webpylike/example.py b/examples/webpylike/example.py index 7dc52a87..0bb6896d 100644 --- a/examples/webpylike/example.py +++ b/examples/webpylike/example.py @@ -8,23 +8,22 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from webpylike.webpylike import WebPyApp, View, Response +from .webpylike import Response +from .webpylike import View +from .webpylike import WebPyApp -urls = ( - '/', 'index', - '/about', 'about' -) +urls = ("/", "index", "/about", "about") class index(View): def GET(self): - return Response('Hello World') + return Response("Hello World") class about(View): def GET(self): - return Response('This is the about page') + return Response("This is the about page") app = WebPyApp(urls, globals()) diff --git a/examples/webpylike/webpylike.py b/examples/webpylike/webpylike.py index 6669ec63..0a7325ca 100644 --- a/examples/webpylike/webpylike.py +++ b/examples/webpylike/webpylike.py @@ -11,9 +11,13 @@ :license: BSD-3-Clause """ import re -from werkzeug.wrappers import BaseRequest, BaseResponse -from werkzeug.exceptions import HTTPException, MethodNotAllowed, \ - NotImplemented, NotFound + +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import MethodNotAllowed +from werkzeug.exceptions import NotFound +from werkzeug.exceptions import NotImplemented +from werkzeug.wrappers import BaseRequest +from werkzeug.wrappers import BaseResponse class Request(BaseRequest): @@ -33,6 +37,7 @@ class View(object): def GET(self): raise MethodNotAllowed() + POST = DELETE = PUT = GET def HEAD(self): @@ -46,8 +51,9 @@ class WebPyApp(object): """ def __init__(self, urls, views): - self.urls = [(re.compile('^%s$' % urls[i]), urls[i + 1]) - for i in range(0, len(urls), 2)] + self.urls = [ + (re.compile("^%s$" % urls[i]), urls[i + 1]) for i in range(0, len(urls), 2) + ] self.views = views def __call__(self, environ, start_response): @@ -57,9 +63,8 @@ class WebPyApp(object): match = regex.match(req.path) if match is not None: view = self.views[view](self, req) - if req.method not in ('GET', 'HEAD', 'POST', - 'DELETE', 'PUT'): - raise NotImplemented() + if req.method not in ("GET", "HEAD", "POST", "DELETE", "PUT"): + raise NotImplemented() # noqa: F901 resp = getattr(view, req.method)(*match.groups()) break else: @@ -28,8 +28,27 @@ source = .tox/*/site-packages/werkzeug [flake8] -ignore = E126,E241,E272,E305,E402,E731,W503 -exclude=.tox,examples,docs -max-line-length=100 +# B = bugbear +# E = pycodestyle errors +# F = flake8 pyflakes +# W = pycodestyle warnings +# B9 = bugbear opinions +select = B, E, F, W, B9 +ignore = + # slice notation whitespace, invalid + E203 + # import at top, too many circular import fixes + E402 + # line length, handled by bugbear B950 + E501 + # bare except, handled by bugbear B001 + E722 + # bin op line break, invalid + W503 +# up to 88 allowed by bugbear B950 +max-line-length = 80 per-file-ignores = - src/werkzeug/wrappers/__init__.py: F401 + # __init__ modules export names + **/__init__.py: F401 + # LocalProxy assigns lambdas + src/werkzeug/local.py: E731 @@ -1,17 +1,17 @@ import io import re -from setuptools import find_packages, setup +from setuptools import find_packages +from setuptools import setup -with io.open('README.rst', 'rt', encoding='utf8') as f: +with io.open("README.rst", "rt", encoding="utf8") as f: readme = f.read() -with io.open('src/werkzeug/__init__.py', 'rt', encoding='utf8') as f: - version = re.search( - r'__version__ = \'(.*?)\'', f.read(), re.M).group(1) +with io.open("src/werkzeug/__init__.py", "rt", encoding="utf8") as f: + version = re.search(r'__version__ = "(.*?)"', f.read(), re.M).group(1) setup( - name='Werkzeug', + name="Werkzeug", version=version, url="https://palletsprojects.com/p/werkzeug/", project_urls={ @@ -19,53 +19,44 @@ setup( "Code": "https://github.com/pallets/werkzeug", "Issue tracker": "https://github.com/pallets/werkzeug/issues", }, - license='BSD-3-Clause', - author='Armin Ronacher', - author_email='armin.ronacher@active-4.com', + license="BSD-3-Clause", + author="Armin Ronacher", + author_email="armin.ronacher@active-4.com", maintainer="The Pallets Team", maintainer_email="contact@palletsprojects.com", - description='The comprehensive WSGI web application library.', + description="The comprehensive WSGI web application library.", long_description=readme, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", - 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', - 'Topic :: Internet :: WWW/HTTP :: WSGI', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Application', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware', - 'Topic :: Software Development :: Libraries :: Application Frameworks', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Topic :: Internet :: WWW/HTTP :: Dynamic Content", + "Topic :: Internet :: WWW/HTTP :: WSGI", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Topic :: Software Development :: Libraries :: Python Modules", ], packages=find_packages("src"), package_dir={"": "src"}, include_package_data=True, python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", extras_require={ - 'watchdog': ['watchdog'], - 'termcolor': ['termcolor'], - 'dev': [ - 'pytest', - 'coverage', - 'tox', - 'sphinx', - 'pallets-sphinx-themes', - ], - 'docs': [ - 'sphinx', - 'pallets-sphinx-themes', - ] + "watchdog": ["watchdog"], + "termcolor": ["termcolor"], + "dev": ["pytest", "coverage", "tox", "sphinx", "pallets-sphinx-themes"], + "docs": ["sphinx", "pallets-sphinx-themes"], }, ) diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py index f8d0d447..5a0c4e9e 100644 --- a/src/werkzeug/__init__.py +++ b/src/werkzeug/__init__.py @@ -14,12 +14,10 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from types import ModuleType import sys +from types import ModuleType -from werkzeug._compat import iteritems - -__version__ = '0.15.dev' +__version__ = "0.15.dev" # This import magic raises concerns quite often which is why the implementation # and motivation is explained here in detail now. @@ -34,82 +32,153 @@ __version__ = '0.15.dev' # werkzeug package when imported from within. Attribute access to the werkzeug # module will then lazily import from the modules that implement the objects. - # import mapping to objects in other modules all_by_module = { - 'werkzeug.debug': ['DebuggedApplication'], - 'werkzeug.local': ['Local', 'LocalManager', 'LocalProxy', 'LocalStack', - 'release_local'], - 'werkzeug.serving': ['run_simple'], - 'werkzeug.test': ['Client', 'EnvironBuilder', 'create_environ', - 'run_wsgi_app'], - 'werkzeug.testapp': ['test_app'], - 'werkzeug.exceptions': ['abort', 'Aborter'], - 'werkzeug.urls': ['url_decode', 'url_encode', 'url_quote', - 'url_quote_plus', 'url_unquote', 'url_unquote_plus', - 'url_fix', 'Href', 'iri_to_uri', 'uri_to_iri'], - 'werkzeug.formparser': ['parse_form_data'], - 'werkzeug.utils': ['escape', 'environ_property', 'append_slash_redirect', - 'redirect', 'cached_property', 'import_string', - 'dump_cookie', 'parse_cookie', 'unescape', - 'format_string', 'find_modules', 'header_property', - 'html', 'xhtml', 'HTMLBuilder', 'validate_arguments', - 'ArgumentValidationError', 'bind_arguments', - 'secure_filename'], - 'werkzeug.wsgi': ['get_current_url', 'get_host', 'pop_path_info', - 'peek_path_info', - 'ClosingIterator', 'FileWrapper', - 'make_line_iter', 'LimitedStream', 'responder', - 'wrap_file', 'extract_path_info'], - 'werkzeug.datastructures': ['MultiDict', 'CombinedMultiDict', 'Headers', - 'EnvironHeaders', 'ImmutableList', - 'ImmutableDict', 'ImmutableMultiDict', - 'TypeConversionDict', - 'ImmutableTypeConversionDict', 'Accept', - 'MIMEAccept', 'CharsetAccept', - 'LanguageAccept', 'RequestCacheControl', - 'ResponseCacheControl', 'ETags', 'HeaderSet', - 'WWWAuthenticate', 'Authorization', - 'FileMultiDict', 'CallbackDict', 'FileStorage', - 'OrderedMultiDict', 'ImmutableOrderedMultiDict' - ], - 'werkzeug.useragents': ['UserAgent'], - 'werkzeug.http': ['parse_etags', 'parse_date', 'http_date', 'cookie_date', - 'parse_cache_control_header', 'is_resource_modified', - 'parse_accept_header', 'parse_set_header', 'quote_etag', - 'unquote_etag', 'generate_etag', 'dump_header', - 'parse_list_header', 'parse_dict_header', - 'parse_authorization_header', - 'parse_www_authenticate_header', 'remove_entity_headers', - 'is_entity_header', 'remove_hop_by_hop_headers', - 'parse_options_header', 'dump_options_header', - 'is_hop_by_hop_header', 'unquote_header_value', - 'quote_header_value', 'HTTP_STATUS_CODES'], - 'werkzeug.wrappers': ['BaseResponse', 'BaseRequest', 'Request', 'Response', - 'AcceptMixin', 'ETagRequestMixin', - 'ETagResponseMixin', 'ResponseStreamMixin', - 'CommonResponseDescriptorsMixin', 'UserAgentMixin', - 'AuthorizationMixin', 'WWWAuthenticateMixin', - 'CommonRequestDescriptorsMixin'], + "werkzeug.debug": ["DebuggedApplication"], + "werkzeug.local": [ + "Local", + "LocalManager", + "LocalProxy", + "LocalStack", + "release_local", + ], + "werkzeug.serving": ["run_simple"], + "werkzeug.test": ["Client", "EnvironBuilder", "create_environ", "run_wsgi_app"], + "werkzeug.testapp": ["test_app"], + "werkzeug.exceptions": ["abort", "Aborter"], + "werkzeug.urls": [ + "url_decode", + "url_encode", + "url_quote", + "url_quote_plus", + "url_unquote", + "url_unquote_plus", + "url_fix", + "Href", + "iri_to_uri", + "uri_to_iri", + ], + "werkzeug.formparser": ["parse_form_data"], + "werkzeug.utils": [ + "escape", + "environ_property", + "append_slash_redirect", + "redirect", + "cached_property", + "import_string", + "dump_cookie", + "parse_cookie", + "unescape", + "format_string", + "find_modules", + "header_property", + "html", + "xhtml", + "HTMLBuilder", + "validate_arguments", + "ArgumentValidationError", + "bind_arguments", + "secure_filename", + ], + "werkzeug.wsgi": [ + "get_current_url", + "get_host", + "pop_path_info", + "peek_path_info", + "ClosingIterator", + "FileWrapper", + "make_line_iter", + "LimitedStream", + "responder", + "wrap_file", + "extract_path_info", + ], + "werkzeug.datastructures": [ + "MultiDict", + "CombinedMultiDict", + "Headers", + "EnvironHeaders", + "ImmutableList", + "ImmutableDict", + "ImmutableMultiDict", + "TypeConversionDict", + "ImmutableTypeConversionDict", + "Accept", + "MIMEAccept", + "CharsetAccept", + "LanguageAccept", + "RequestCacheControl", + "ResponseCacheControl", + "ETags", + "HeaderSet", + "WWWAuthenticate", + "Authorization", + "FileMultiDict", + "CallbackDict", + "FileStorage", + "OrderedMultiDict", + "ImmutableOrderedMultiDict", + ], + "werkzeug.useragents": ["UserAgent"], + "werkzeug.http": [ + "parse_etags", + "parse_date", + "http_date", + "cookie_date", + "parse_cache_control_header", + "is_resource_modified", + "parse_accept_header", + "parse_set_header", + "quote_etag", + "unquote_etag", + "generate_etag", + "dump_header", + "parse_list_header", + "parse_dict_header", + "parse_authorization_header", + "parse_www_authenticate_header", + "remove_entity_headers", + "is_entity_header", + "remove_hop_by_hop_headers", + "parse_options_header", + "dump_options_header", + "is_hop_by_hop_header", + "unquote_header_value", + "quote_header_value", + "HTTP_STATUS_CODES", + ], + "werkzeug.wrappers": [ + "BaseResponse", + "BaseRequest", + "Request", + "Response", + "AcceptMixin", + "ETagRequestMixin", + "ETagResponseMixin", + "ResponseStreamMixin", + "CommonResponseDescriptorsMixin", + "UserAgentMixin", + "AuthorizationMixin", + "WWWAuthenticateMixin", + "CommonRequestDescriptorsMixin", + ], "werkzeug.middleware.dispatcher": ["DispatcherMiddleware"], "werkzeug.middleware.shared_data": ["SharedDataMiddleware"], - 'werkzeug.security': ['generate_password_hash', 'check_password_hash'], + "werkzeug.security": ["generate_password_hash", "check_password_hash"], # the undocumented easteregg ;-) - 'werkzeug._internal': ['_easteregg'] + "werkzeug._internal": ["_easteregg"], } # modules that should be imported when accessed as attributes of werkzeug -attribute_modules = frozenset(['exceptions', 'routing']) - +attribute_modules = frozenset(["exceptions", "routing"]) object_origins = {} -for module, items in iteritems(all_by_module): +for module, items in all_by_module.items(): for item in items: object_origins[item] = module class module(ModuleType): - """Automatically import objects from the modules.""" def __getattr__(self, name): @@ -119,34 +188,46 @@ class module(ModuleType): setattr(self, extra_name, getattr(module, extra_name)) return getattr(module, name) elif name in attribute_modules: - __import__('werkzeug.' + name) + __import__("werkzeug." + name) return ModuleType.__getattribute__(self, name) def __dir__(self): """Just show what we want to show.""" result = list(new_module.__all__) - result.extend(('__file__', '__doc__', '__all__', - '__docformat__', '__name__', '__path__', - '__package__', '__version__')) + result.extend( + ( + "__file__", + "__doc__", + "__all__", + "__docformat__", + "__name__", + "__path__", + "__package__", + "__version__", + ) + ) return result + # keep a reference to this module so that it's not garbage collected -old_module = sys.modules['werkzeug'] +old_module = sys.modules["werkzeug"] # setup the new module and patch it into the dict of loaded modules -new_module = sys.modules['werkzeug'] = module('werkzeug') -new_module.__dict__.update({ - '__file__': __file__, - '__package__': 'werkzeug', - '__path__': __path__, - '__doc__': __doc__, - '__version__': __version__, - '__all__': tuple(object_origins) + tuple(attribute_modules), - '__docformat__': 'restructuredtext en' -}) +new_module = sys.modules["werkzeug"] = module("werkzeug") +new_module.__dict__.update( + { + "__file__": __file__, + "__package__": "werkzeug", + "__path__": __path__, + "__doc__": __doc__, + "__version__": __version__, + "__all__": tuple(object_origins) + tuple(attribute_modules), + "__docformat__": "restructuredtext en", + } +) # Due to bootstrapping issues we need to import exceptions here. # Don't ask :-( -__import__('werkzeug.exceptions') +__import__("werkzeug.exceptions") diff --git a/src/werkzeug/_compat.py b/src/werkzeug/_compat.py index f9c5b343..1097983e 100644 --- a/src/werkzeug/_compat.py +++ b/src/werkzeug/_compat.py @@ -1,10 +1,8 @@ # flake8: noqa # This whole file is full of lint errors -import codecs -import sys -import operator import functools -import warnings +import operator +import sys try: import builtins @@ -13,7 +11,7 @@ except ImportError: PY2 = sys.version_info[0] == 2 -WIN = sys.platform.startswith('win') +WIN = sys.platform.startswith("win") _identity = lambda x: x @@ -35,15 +33,19 @@ if PY2: import collections as collections_abc - exec('def reraise(tp, value, tb=None):\n raise tp, value, tb') + exec("def reraise(tp, value, tb=None):\n raise tp, value, tb") def fix_tuple_repr(obj): def __repr__(self): cls = self.__class__ - return '%s(%s)' % (cls.__name__, ', '.join( - '%s=%r' % (field, self[index]) - for index, field in enumerate(cls._fields) - )) + return "%s(%s)" % ( + cls.__name__, + ", ".join( + "%s=%r" % (field, self[index]) + for index, field in enumerate(cls._fields) + ), + ) + obj.__repr__ = __repr__ return obj @@ -54,12 +56,13 @@ if PY2: def implements_to_string(cls): cls.__unicode__ = cls.__str__ - cls.__str__ = lambda x: x.__unicode__().encode('utf-8') + cls.__str__ = lambda x: x.__unicode__().encode("utf-8") return cls def native_string_result(func): def wrapper(*args, **kwargs): - return func(*args, **kwargs).encode('utf-8') + return func(*args, **kwargs).encode("utf-8") + return functools.update_wrapper(wrapper, func) def implements_bool(cls): @@ -68,10 +71,12 @@ if PY2: return cls from itertools import imap, izip, ifilter + range_type = xrange from StringIO import StringIO from cStringIO import StringIO as BytesIO + NativeStringIO = BytesIO def make_literal_wrapper(reference): @@ -96,33 +101,34 @@ if PY2: wsgi_get_bytes = _identity - def wsgi_decoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_decoding_dance(s, charset="utf-8", errors="replace"): return s.decode(charset, errors) - def wsgi_encoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_encoding_dance(s, charset="utf-8", errors="replace"): if isinstance(s, bytes): return s return s.encode(charset, errors) - def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None: return None if isinstance(x, (bytes, bytearray, buffer)): return bytes(x) if isinstance(x, unicode): return x.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_native(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None or isinstance(x, str): return x return x.encode(charset, errors) + else: unichr = chr text_type = str - string_types = (str, ) - integer_types = (int, ) + string_types = (str,) + integer_types = (int,) iterkeys = lambda d, *args, **kwargs: iter(d.keys(*args, **kwargs)) itervalues = lambda d, *args, **kwargs: iter(d.values(*args, **kwargs)) @@ -131,7 +137,7 @@ else: iterlists = lambda d, *args, **kwargs: iter(d.lists(*args, **kwargs)) iterlistvalues = lambda d, *args, **kwargs: iter(d.listvalues(*args, **kwargs)) - int_to_byte = operator.methodcaller('to_bytes', 1, 'big') + int_to_byte = operator.methodcaller("to_bytes", 1, "big") iter_bytes = functools.partial(map, int_to_byte) import collections.abc as collections_abc @@ -152,9 +158,10 @@ else: range_type = range from io import StringIO, BytesIO + NativeStringIO = StringIO - _latin1_encode = operator.methodcaller('encode', 'latin1') + _latin1_encode = operator.methodcaller("encode", "latin1") def make_literal_wrapper(reference): if isinstance(reference, text_type): @@ -169,38 +176,40 @@ else: is_text = isinstance(next(tupiter, None), text_type) for arg in tupiter: if isinstance(arg, text_type) != is_text: - raise TypeError('Cannot mix str and bytes arguments (got %s)' - % repr(tup)) + raise TypeError( + "Cannot mix str and bytes arguments (got %s)" % repr(tup) + ) return tup try_coerce_native = _identity wsgi_get_bytes = _latin1_encode - def wsgi_decoding_dance(s, charset='utf-8', errors='replace'): - return s.encode('latin1').decode(charset, errors) + def wsgi_decoding_dance(s, charset="utf-8", errors="replace"): + return s.encode("latin1").decode(charset, errors) - def wsgi_encoding_dance(s, charset='utf-8', errors='replace'): + def wsgi_encoding_dance(s, charset="utf-8", errors="replace"): if isinstance(s, text_type): s = s.encode(charset) - return s.decode('latin1', errors) + return s.decode("latin1", errors) - def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None: return None if isinstance(x, (bytes, bytearray, memoryview)): # noqa return bytes(x) if isinstance(x, str): return x.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): + def to_native(x, charset=sys.getdefaultencoding(), errors="strict"): if x is None or isinstance(x, str): return x return x.decode(charset, errors) -def to_unicode(x, charset=sys.getdefaultencoding(), errors='strict', - allow_none_charset=False): +def to_unicode( + x, charset=sys.getdefaultencoding(), errors="strict", allow_none_charset=False +): if x is None: return None if not isinstance(x, bytes): diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 36084438..d8b83363 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -8,41 +8,46 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import inspect import re import string -import inspect -from weakref import WeakKeyDictionary -from datetime import datetime, date +from datetime import date +from datetime import datetime from itertools import chain +from weakref import WeakKeyDictionary -from werkzeug._compat import iter_bytes, text_type, int_to_byte, range_type, \ - integer_types +from ._compat import int_to_byte +from ._compat import integer_types +from ._compat import iter_bytes +from ._compat import range_type +from ._compat import text_type _logger = None _signature_cache = WeakKeyDictionary() _epoch_ord = date(1970, 1, 1).toordinal() -_cookie_params = set((b'expires', b'path', b'comment', - b'max-age', b'secure', b'httponly', - b'version')) -_legal_cookie_chars = (string.ascii_letters - + string.digits - + u"/=!#$%&'*+-.^_`|~:").encode('ascii') - -_cookie_quoting_map = { - b',': b'\\054', - b';': b'\\073', - b'"': b'\\"', - b'\\': b'\\\\', +_cookie_params = { + b"expires", + b"path", + b"comment", + b"max-age", + b"secure", + b"httponly", + b"version", } -for _i in chain(range_type(32), range_type(127, 256)): - _cookie_quoting_map[int_to_byte(_i)] = ('\\%03o' % _i).encode('latin1') +_legal_cookie_chars = ( + string.ascii_letters + string.digits + u"/=!#$%&'*+-.^_`|~:" +).encode("ascii") +_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"} +for _i in chain(range_type(32), range_type(127, 256)): + _cookie_quoting_map[int_to_byte(_i)] = ("\\%03o" % _i).encode("latin1") -_octal_re = re.compile(br'\\[0-3][0-7][0-7]') -_quote_re = re.compile(br'[\\].') -_legal_cookie_chars_re = br'[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]' -_cookie_re = re.compile(br""" +_octal_re = re.compile(br"\\[0-3][0-7][0-7]") +_quote_re = re.compile(br"[\\].") +_legal_cookie_chars_re = br"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]" +_cookie_re = re.compile( + br""" (?P<key>[^=;]+) (?:\s*=\s* (?P<val> @@ -51,24 +56,27 @@ _cookie_re = re.compile(br""" ) )? \s*; -""", flags=re.VERBOSE) +""", + flags=re.VERBOSE, +) class _Missing(object): - def __repr__(self): - return 'no value' + return "no value" def __reduce__(self): - return '_missing' + return "_missing" + _missing = _Missing() def _get_environ(obj): - env = getattr(obj, 'environ', obj) - assert isinstance(env, dict), \ - '%r is not a WSGI environment (has to be a dict)' % type(obj).__name__ + env = getattr(obj, "environ", obj) + assert isinstance(env, dict), ( + "%r is not a WSGI environment (has to be a dict)" % type(obj).__name__ + ) return env @@ -77,7 +85,8 @@ def _log(type, message, *args, **kwargs): global _logger if _logger is None: import logging - _logger = logging.getLogger('werkzeug') + + _logger = logging.getLogger("werkzeug") if _logger.level == logging.NOTSET: _logger.setLevel(logging.INFO) # Only set up a default log handler if the @@ -90,7 +99,7 @@ def _log(type, message, *args, **kwargs): def _parse_signature(func): """Return a signature object for the function.""" - if hasattr(func, 'im_func'): + if hasattr(func, "im_func"): func = func.im_func # if we have a cached validator for this function, return it @@ -99,7 +108,7 @@ def _parse_signature(func): return parse # inspect the function signature and collect all the information - if hasattr(inspect, 'getfullargspec'): + if hasattr(inspect, "getfullargspec"): tup = inspect.getfullargspec(func) else: tup = inspect.getargspec(func) @@ -109,8 +118,9 @@ def _parse_signature(func): arguments = [] for idx, name in enumerate(positional): if isinstance(name, list): - raise TypeError('cannot parse functions that unpack tuples ' - 'in the function signature') + raise TypeError( + "cannot parse functions that unpack tuples in the function signature" + ) try: default = defaults[idx - arg_count] except IndexError: @@ -150,8 +160,17 @@ def _parse_signature(func): extra.update(kwargs) kwargs = {} - return new_args, kwargs, missing, extra, extra_positional, \ - arguments, vararg_var, kwarg_var + return ( + new_args, + kwargs, + missing, + extra, + extra_positional, + arguments, + vararg_var, + kwarg_var, + ) + _signature_cache[func] = parse return parse @@ -173,12 +192,19 @@ def _date_to_unix(arg): class _DictAccessorProperty(object): - """Baseclass for `environ_property` and `header_property`.""" + read_only = False - def __init__(self, name, default=None, load_func=None, dump_func=None, - read_only=None, doc=None): + def __init__( + self, + name, + default=None, + load_func=None, + dump_func=None, + read_only=None, + doc=None, + ): self.name = name self.default = default self.load_func = load_func @@ -203,21 +229,18 @@ class _DictAccessorProperty(object): def __set__(self, obj, value): if self.read_only: - raise AttributeError('read only property') + raise AttributeError("read only property") if self.dump_func is not None: value = self.dump_func(value) self.lookup(obj)[self.name] = value def __delete__(self, obj): if self.read_only: - raise AttributeError('read only property') + raise AttributeError("read only property") self.lookup(obj).pop(self.name, None) def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, - self.name - ) + return "<%s %s>" % (self.__class__.__name__, self.name) def _cookie_quote(b): @@ -263,11 +286,11 @@ def _cookie_unquote(b): k = q_match.start(0) if q_match and (not o_match or k < j): _push(b[i:k]) - _push(b[k + 1:k + 2]) + _push(b[k + 1 : k + 2]) i = k + 2 else: _push(b[i:j]) - rv.append(int(b[j + 1:j + 4], 8)) + rv.append(int(b[j + 1 : j + 4], 8)) i = j + 4 return bytes(rv) @@ -279,12 +302,12 @@ def _cookie_parse_impl(b): n = len(b) while i < n: - match = _cookie_re.search(b + b';', i) + match = _cookie_re.search(b + b";", i) if not match: break - key = match.group('key').strip() - value = match.group('val') or b'' + key = match.group("key").strip() + value = match.group("val") or b"" i = match.end(0) # Ignore parameters. We have no interest in them. @@ -295,20 +318,20 @@ def _cookie_parse_impl(b): def _encode_idna(domain): # If we're given bytes, make sure they fit into ASCII if not isinstance(domain, text_type): - domain.decode('ascii') + domain.decode("ascii") return domain # Otherwise check if it's already ascii, then return try: - return domain.encode('ascii') + return domain.encode("ascii") except UnicodeError: pass # Otherwise encode each part separately - parts = domain.split('.') + parts = domain.split(".") for idx, part in enumerate(parts): - parts[idx] = part.encode('idna') - return b'.'.join(parts) + parts[idx] = part.encode("idna") + return b".".join(parts) def _decode_idna(domain): @@ -317,47 +340,54 @@ def _decode_idna(domain): # unicode error, then we already have a decoded idna domain if isinstance(domain, text_type): try: - domain = domain.encode('ascii') + domain = domain.encode("ascii") except UnicodeError: return domain # Decode each part separately. If a part fails, try to # decode it with ascii and silently ignore errors. This makes # most sense because the idna codec does not have error handling - parts = domain.split(b'.') + parts = domain.split(b".") for idx, part in enumerate(parts): try: - parts[idx] = part.decode('idna') + parts[idx] = part.decode("idna") except UnicodeError: - parts[idx] = part.decode('ascii', 'ignore') + parts[idx] = part.decode("ascii", "ignore") - return '.'.join(parts) + return ".".join(parts) def _make_cookie_domain(domain): if domain is None: return None domain = _encode_idna(domain) - if b':' in domain: - domain = domain.split(b':', 1)[0] - if b'.' in domain: + if b":" in domain: + domain = domain.split(b":", 1)[0] + if b"." in domain: return domain raise ValueError( - 'Setting \'domain\' for a cookie on a server running locally (ex: ' - 'localhost) is not supported by complying browsers. You should ' - 'have something like: \'127.0.0.1 localhost dev.localhost\' on ' - 'your hosts file and then point your server to run on ' - '\'dev.localhost\' and also set \'domain\' for \'dev.localhost\'' + "Setting 'domain' for a cookie on a server running locally (ex: " + "localhost) is not supported by complying browsers. You should " + "have something like: '127.0.0.1 localhost dev.localhost' on " + "your hosts file and then point your server to run on " + "'dev.localhost' and also set 'domain' for 'dev.localhost'" ) def _easteregg(app=None): """Like the name says. But who knows how it works?""" + def bzzzzzzz(gyver): import base64 import zlib - return zlib.decompress(base64.b64decode(gyver)).decode('ascii') - gyver = u'\n'.join([x + (77 - len(x)) * u' ' for x in bzzzzzzz(b''' + + return zlib.decompress(base64.b64decode(gyver)).decode("ascii") + + gyver = u"\n".join( + [ + x + (77 - len(x)) * u" " + for x in bzzzzzzz( + b""" eJyFlzuOJDkMRP06xRjymKgDJCDQStBYT8BCgK4gTwfQ2fcFs2a2FzvZk+hvlcRvRJD148efHt9m 9Xz94dRY5hGt1nrYcXx7us9qlcP9HHNh28rz8dZj+q4rynVFFPdlY4zH873NKCexrDM6zxxRymzz 4QIxzK4bth1PV7+uHn6WXZ5C4ka/+prFzx3zWLMHAVZb8RRUxtFXI5DTQ2n3Hi2sNI+HK43AOWSY @@ -388,16 +418,22 @@ p1qXK3Du2mnr5INXmT/78KI12n11EFBkJHHp0wJyLe9MvPNUGYsf+170maayRoy2lURGHAIapSpQ krEDuNoJCHNlZYhKpvw4mspVWxqo415n8cD62N9+EfHrAvqQnINStetek7RY2Urv8nxsnGaZfRr/ nhXbJ6m/yl1LzYqscDZA9QHLNbdaSTTr+kFg3bC0iYbX/eQy0Bv3h4B50/SGYzKAXkCeOLI3bcAt mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p -7f2zLkGNv8b191cD/3vs9Q833z8t''').splitlines()]) +7f2zLkGNv8b191cD/3vs9Q833z8t""" + ).splitlines() + ] + ) def easteregged(environ, start_response): def injecting_start_response(status, headers, exc_info=None): - headers.append(('X-Powered-By', 'Werkzeug')) + headers.append(("X-Powered-By", "Werkzeug")) return start_response(status, headers, exc_info) - if app is not None and environ.get('QUERY_STRING') != 'macgybarchakku': + + if app is not None and environ.get("QUERY_STRING") != "macgybarchakku": return app(environ, injecting_start_response) - injecting_start_response('200 OK', [('Content-Type', 'text/html')]) - return [(u''' + injecting_start_response("200 OK", [("Content-Type", "text/html")]) + return [ + ( + u""" <!DOCTYPE html> <html> <head> @@ -415,5 +451,9 @@ mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p <p>the Swiss Army knife of Python web development.</p> <pre>%s\n\n\n</pre> </body> -</html>''' % gyver).encode('latin1')] +</html>""" + % gyver + ).encode("latin1") + ] + return easteregged diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 2b21e146..f06a63d5 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -1,12 +1,14 @@ import os -import sys -import time import subprocess +import sys import threading +import time from itertools import chain -from werkzeug._internal import _log -from werkzeug._compat import PY2, iteritems, text_type +from ._compat import iteritems +from ._compat import PY2 +from ._compat import text_type +from ._internal import _log def _iter_module_files(): @@ -19,10 +21,11 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: - if os.path.isdir(filename) and \ - os.path.exists(os.path.join(filename, "__init__.py")): + if os.path.isdir(filename) and os.path.exists( + os.path.join(filename, "__init__.py") + ): filename = os.path.join(filename, "__init__.py") old = None @@ -32,22 +35,23 @@ def _iter_module_files(): if filename == old: break else: - if filename[-4:] in ('.pyc', '.pyo'): + if filename[-4:] in (".pyc", ".pyo"): filename = filename[:-1] yield filename def _find_observable_paths(extra_files=None): """Finds all paths that should be observed.""" - rv = set(os.path.dirname(os.path.abspath(x)) - if os.path.isfile(x) else os.path.abspath(x) - for x in sys.path) + rv = set( + os.path.dirname(os.path.abspath(x)) if os.path.isfile(x) else os.path.abspath(x) + for x in sys.path + ) for filename in extra_files or (): rv.add(os.path.dirname(os.path.abspath(filename))) for module in list(sys.modules.values()): - fn = getattr(module, '__file__', None) + fn = getattr(module, "__file__", None) if fn is None: continue fn = os.path.abspath(fn) @@ -125,7 +129,8 @@ def _find_common_roots(paths): for prefix, child in iteritems(node): _walk(child, path + (prefix,)) if not node: - rv.add('/'.join(path)) + rv.add("/".join(path)) + _walk(root, ()) return rv @@ -139,8 +144,7 @@ class ReloaderLoop(object): _sleep = staticmethod(time.sleep) def __init__(self, extra_files=None, interval=1): - self.extra_files = set(os.path.abspath(x) - for x in extra_files or ()) + self.extra_files = set(os.path.abspath(x) for x in extra_files or ()) self.interval = interval def run(self): @@ -151,26 +155,25 @@ class ReloaderLoop(object): but running the reloader thread. """ while 1: - _log('info', ' * Restarting with %s' % self.name) + _log("info", " * Restarting with %s" % self.name) args = _get_args_for_reloading() # a weird bug on windows. sometimes unicode strings end up in the # environment and subprocess.call does not like this, encode them # to latin1 and continue. - if os.name == 'nt' and PY2: + if os.name == "nt" and PY2: new_environ = {} for key, value in iteritems(os.environ): if isinstance(key, text_type): - key = key.encode('iso-8859-1') + key = key.encode("iso-8859-1") if isinstance(value, text_type): - value = value.encode('iso-8859-1') + value = value.encode("iso-8859-1") new_environ[key] = value else: new_environ = os.environ.copy() - new_environ['WERKZEUG_RUN_MAIN'] = 'true' - exit_code = subprocess.call(args, env=new_environ, - close_fds=False) + new_environ["WERKZEUG_RUN_MAIN"] = "true" + exit_code = subprocess.call(args, env=new_environ, close_fds=False) if exit_code != 3: return exit_code @@ -180,17 +183,16 @@ class ReloaderLoop(object): def log_reload(self, filename): filename = os.path.abspath(filename) - _log('info', ' * Detected change in %r, reloading' % filename) + _log("info", " * Detected change in %r, reloading" % filename) class StatReloaderLoop(ReloaderLoop): - name = 'stat' + name = "stat" def run(self): mtimes = {} while 1: - for filename in chain(_iter_module_files(), - self.extra_files): + for filename in chain(_iter_module_files(), self.extra_files): try: mtime = os.stat(filename).st_mtime except OSError: @@ -206,11 +208,11 @@ class StatReloaderLoop(ReloaderLoop): class WatchdogReloaderLoop(ReloaderLoop): - def __init__(self, *args, **kwargs): ReloaderLoop.__init__(self, *args, **kwargs) from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler + self.observable_paths = set() def _check_modification(filename): @@ -218,11 +220,10 @@ class WatchdogReloaderLoop(ReloaderLoop): self.trigger_reload(filename) dirname = os.path.dirname(filename) if dirname.startswith(tuple(self.observable_paths)): - if filename.endswith(('.pyc', '.pyo', '.py')): + if filename.endswith((".pyc", ".pyo", ".py")): self.trigger_reload(filename) class _CustomHandler(FileSystemEventHandler): - def on_created(self, event): _check_modification(event.src_path) @@ -237,9 +238,9 @@ class WatchdogReloaderLoop(ReloaderLoop): _check_modification(event.src_path) reloader_name = Observer.__name__.lower() - if reloader_name.endswith('observer'): + if reloader_name.endswith("observer"): reloader_name = reloader_name[:-8] - reloader_name += ' reloader' + reloader_name += " reloader" self.name = reloader_name @@ -267,7 +268,8 @@ class WatchdogReloaderLoop(ReloaderLoop): if path not in watches: try: watches[path] = observer.schedule( - self.event_handler, path, recursive=True) + self.event_handler, path, recursive=True + ) except OSError: # Clear this path from list of watches We don't want # the same error message showing again in the next @@ -287,17 +289,14 @@ class WatchdogReloaderLoop(ReloaderLoop): sys.exit(3) -reloader_loops = { - 'stat': StatReloaderLoop, - 'watchdog': WatchdogReloaderLoop, -} +reloader_loops = {"stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop} try: - __import__('watchdog.observers') + __import__("watchdog.observers") except ImportError: - reloader_loops['auto'] = reloader_loops['stat'] + reloader_loops["auto"] = reloader_loops["stat"] else: - reloader_loops['auto'] = reloader_loops['watchdog'] + reloader_loops["auto"] = reloader_loops["watchdog"] def ensure_echo_on(): @@ -316,14 +315,14 @@ def ensure_echo_on(): termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes) -def run_with_reloader(main_func, extra_files=None, interval=1, - reloader_type='auto'): +def run_with_reloader(main_func, extra_files=None, interval=1, reloader_type="auto"): """Run the given function in an independent python interpreter.""" import signal + reloader = reloader_loops[reloader_type](extra_files, interval) signal.signal(signal.SIGTERM, lambda *args: sys.exit(0)) try: - if os.environ.get('WERKZEUG_RUN_MAIN') == 'true': + if os.environ.get("WERKZEUG_RUN_MAIN") == "true": ensure_echo_on() t = threading.Thread(target=main_func, args=()) t.setDaemon(True) diff --git a/src/werkzeug/contrib/atom.py b/src/werkzeug/contrib/atom.py index 5c631f43..d079d2bf 100644 --- a/src/werkzeug/contrib/atom.py +++ b/src/werkzeug/contrib/atom.py @@ -23,9 +23,11 @@ """ import warnings from datetime import datetime -from werkzeug.utils import escape -from werkzeug.wrappers import BaseResponse -from werkzeug._compat import implements_to_string, string_types + +from .._compat import implements_to_string +from .._compat import string_types +from ..utils import escape +from ..wrappers import BaseResponse warnings.warn( "'werkzeug.contrib.atom' is deprecated as of version 0.15 and will" @@ -34,18 +36,21 @@ warnings.warn( stacklevel=2, ) -XHTML_NAMESPACE = 'http://www.w3.org/1999/xhtml' +XHTML_NAMESPACE = "http://www.w3.org/1999/xhtml" def _make_text_block(name, content, content_type=None): """Helper function for the builder that creates an XML text block.""" - if content_type == 'xhtml': - return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % \ - (name, XHTML_NAMESPACE, content, name) + if content_type == "xhtml": + return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % ( + name, + XHTML_NAMESPACE, + content, + name, + ) if not content_type: - return u'<%s>%s</%s>\n' % (name, escape(content), name) - return u'<%s type="%s">%s</%s>\n' % (name, content_type, - escape(content), name) + return u"<%s>%s</%s>\n" % (name, escape(content), name) + return u'<%s type="%s">%s</%s>\n' % (name, content_type, escape(content), name) def format_iso8601(obj): @@ -53,7 +58,7 @@ def format_iso8601(obj): iso8601 = obj.isoformat() if obj.tzinfo: return iso8601 - return iso8601 + 'Z' + return iso8601 + "Z" @implements_to_string @@ -106,42 +111,43 @@ class AtomFeed(object): Everywhere where a list is demanded, any iterable can be used. """ - default_generator = ('Werkzeug', None, None) + default_generator = ("Werkzeug", None, None) def __init__(self, title=None, entries=None, **kwargs): self.title = title - self.title_type = kwargs.get('title_type', 'text') - self.url = kwargs.get('url') - self.feed_url = kwargs.get('feed_url', self.url) - self.id = kwargs.get('id', self.feed_url) - self.updated = kwargs.get('updated') - self.author = kwargs.get('author', ()) - self.icon = kwargs.get('icon') - self.logo = kwargs.get('logo') - self.rights = kwargs.get('rights') - self.rights_type = kwargs.get('rights_type') - self.subtitle = kwargs.get('subtitle') - self.subtitle_type = kwargs.get('subtitle_type', 'text') - self.generator = kwargs.get('generator') + self.title_type = kwargs.get("title_type", "text") + self.url = kwargs.get("url") + self.feed_url = kwargs.get("feed_url", self.url) + self.id = kwargs.get("id", self.feed_url) + self.updated = kwargs.get("updated") + self.author = kwargs.get("author", ()) + self.icon = kwargs.get("icon") + self.logo = kwargs.get("logo") + self.rights = kwargs.get("rights") + self.rights_type = kwargs.get("rights_type") + self.subtitle = kwargs.get("subtitle") + self.subtitle_type = kwargs.get("subtitle_type", "text") + self.generator = kwargs.get("generator") if self.generator is None: self.generator = self.default_generator - self.links = kwargs.get('links', []) + self.links = kwargs.get("links", []) self.entries = list(entries) if entries else [] - if not hasattr(self.author, '__iter__') \ - or isinstance(self.author, string_types + (dict,)): + if not hasattr(self.author, "__iter__") or isinstance( + self.author, string_types + (dict,) + ): self.author = [self.author] for i, author in enumerate(self.author): if not isinstance(author, dict): - self.author[i] = {'name': author} + self.author[i] = {"name": author} if not self.title: - raise ValueError('title is required') + raise ValueError("title is required") if not self.id: - raise ValueError('id is required') + raise ValueError("id is required") for author in self.author: - if 'name' not in author: - raise TypeError('author must contain at least a name') + if "name" not in author: + raise TypeError("author must contain at least a name") def add(self, *args, **kwargs): """Add a new entry to the feed. This function can either be called @@ -151,14 +157,14 @@ class AtomFeed(object): if len(args) == 1 and not kwargs and isinstance(args[0], FeedEntry): self.entries.append(args[0]) else: - kwargs['feed_url'] = self.feed_url + kwargs["feed_url"] = self.feed_url self.entries.append(FeedEntry(*args, **kwargs)) def __repr__(self): - return '<%s %r (%d entries)>' % ( + return "<%s %r (%d entries)>" % ( self.__class__.__name__, self.title, - len(self.entries) + len(self.entries), ) def generate(self): @@ -166,7 +172,7 @@ class AtomFeed(object): # atom demands either an author element in every entry or a global one if not self.author: if any(not e.author for e in self.entries): - self.author = ({'name': 'Unknown author'},) + self.author = ({"name": "Unknown author"},) if not self.updated: dates = sorted([entry.updated for entry in self.entries]) @@ -174,56 +180,54 @@ class AtomFeed(object): yield u'<?xml version="1.0" encoding="utf-8"?>\n' yield u'<feed xmlns="http://www.w3.org/2005/Atom">\n' - yield ' ' + _make_text_block('title', self.title, self.title_type) - yield u' <id>%s</id>\n' % escape(self.id) - yield u' <updated>%s</updated>\n' % format_iso8601(self.updated) + yield " " + _make_text_block("title", self.title, self.title_type) + yield u" <id>%s</id>\n" % escape(self.id) + yield u" <updated>%s</updated>\n" % format_iso8601(self.updated) if self.url: yield u' <link href="%s" />\n' % escape(self.url) if self.feed_url: - yield u' <link href="%s" rel="self" />\n' % \ - escape(self.feed_url) + yield u' <link href="%s" rel="self" />\n' % escape(self.feed_url) for link in self.links: - yield u' <link %s/>\n' % ''.join('%s="%s" ' % - (k, escape(link[k])) for k in link) + yield u" <link %s/>\n" % "".join( + '%s="%s" ' % (k, escape(link[k])) for k in link + ) for author in self.author: - yield u' <author>\n' - yield u' <name>%s</name>\n' % escape(author['name']) - if 'uri' in author: - yield u' <uri>%s</uri>\n' % escape(author['uri']) - if 'email' in author: - yield ' <email>%s</email>\n' % escape(author['email']) - yield ' </author>\n' + yield u" <author>\n" + yield u" <name>%s</name>\n" % escape(author["name"]) + if "uri" in author: + yield u" <uri>%s</uri>\n" % escape(author["uri"]) + if "email" in author: + yield " <email>%s</email>\n" % escape(author["email"]) + yield " </author>\n" if self.subtitle: - yield ' ' + _make_text_block('subtitle', self.subtitle, - self.subtitle_type) + yield " " + _make_text_block("subtitle", self.subtitle, self.subtitle_type) if self.icon: - yield u' <icon>%s</icon>\n' % escape(self.icon) + yield u" <icon>%s</icon>\n" % escape(self.icon) if self.logo: - yield u' <logo>%s</logo>\n' % escape(self.logo) + yield u" <logo>%s</logo>\n" % escape(self.logo) if self.rights: - yield ' ' + _make_text_block('rights', self.rights, - self.rights_type) + yield " " + _make_text_block("rights", self.rights, self.rights_type) generator_name, generator_url, generator_version = self.generator if generator_name or generator_url or generator_version: - tmp = [u' <generator'] + tmp = [u" <generator"] if generator_url: tmp.append(u' uri="%s"' % escape(generator_url)) if generator_version: tmp.append(u' version="%s"' % escape(generator_version)) - tmp.append(u'>%s</generator>\n' % escape(generator_name)) - yield u''.join(tmp) + tmp.append(u">%s</generator>\n" % escape(generator_name)) + yield u"".join(tmp) for entry in self.entries: for line in entry.generate(): - yield u' ' + line - yield u'</feed>\n' + yield u" " + line + yield u"</feed>\n" def to_string(self): """Convert the feed into a string.""" - return u''.join(self.generate()) + return u"".join(self.generate()) def get_response(self): """Return a response object for the feed.""" - return BaseResponse(self.to_string(), mimetype='application/atom+xml') + return BaseResponse(self.to_string(), mimetype="application/atom+xml") def __call__(self, environ, start_response): """Use the class as WSGI response object.""" @@ -282,80 +286,77 @@ class FeedEntry(object): def __init__(self, title=None, content=None, feed_url=None, **kwargs): self.title = title - self.title_type = kwargs.get('title_type', 'text') + self.title_type = kwargs.get("title_type", "text") self.content = content - self.content_type = kwargs.get('content_type', 'html') - self.url = kwargs.get('url') - self.id = kwargs.get('id', self.url) - self.updated = kwargs.get('updated') - self.summary = kwargs.get('summary') - self.summary_type = kwargs.get('summary_type', 'html') - self.author = kwargs.get('author', ()) - self.published = kwargs.get('published') - self.rights = kwargs.get('rights') - self.links = kwargs.get('links', []) - self.categories = kwargs.get('categories', []) - self.xml_base = kwargs.get('xml_base', feed_url) - - if not hasattr(self.author, '__iter__') \ - or isinstance(self.author, string_types + (dict,)): + self.content_type = kwargs.get("content_type", "html") + self.url = kwargs.get("url") + self.id = kwargs.get("id", self.url) + self.updated = kwargs.get("updated") + self.summary = kwargs.get("summary") + self.summary_type = kwargs.get("summary_type", "html") + self.author = kwargs.get("author", ()) + self.published = kwargs.get("published") + self.rights = kwargs.get("rights") + self.links = kwargs.get("links", []) + self.categories = kwargs.get("categories", []) + self.xml_base = kwargs.get("xml_base", feed_url) + + if not hasattr(self.author, "__iter__") or isinstance( + self.author, string_types + (dict,) + ): self.author = [self.author] for i, author in enumerate(self.author): if not isinstance(author, dict): - self.author[i] = {'name': author} + self.author[i] = {"name": author} if not self.title: - raise ValueError('title is required') + raise ValueError("title is required") if not self.id: - raise ValueError('id is required') + raise ValueError("id is required") if not self.updated: - raise ValueError('updated is required') + raise ValueError("updated is required") def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.title - ) + return "<%s %r>" % (self.__class__.__name__, self.title) def generate(self): """Yields pieces of ATOM XML.""" - base = '' + base = "" if self.xml_base: base = ' xml:base="%s"' % escape(self.xml_base) - yield u'<entry%s>\n' % base - yield u' ' + _make_text_block('title', self.title, self.title_type) - yield u' <id>%s</id>\n' % escape(self.id) - yield u' <updated>%s</updated>\n' % format_iso8601(self.updated) + yield u"<entry%s>\n" % base + yield u" " + _make_text_block("title", self.title, self.title_type) + yield u" <id>%s</id>\n" % escape(self.id) + yield u" <updated>%s</updated>\n" % format_iso8601(self.updated) if self.published: - yield u' <published>%s</published>\n' % \ - format_iso8601(self.published) + yield u" <published>%s</published>\n" % format_iso8601(self.published) if self.url: yield u' <link href="%s" />\n' % escape(self.url) for author in self.author: - yield u' <author>\n' - yield u' <name>%s</name>\n' % escape(author['name']) - if 'uri' in author: - yield u' <uri>%s</uri>\n' % escape(author['uri']) - if 'email' in author: - yield u' <email>%s</email>\n' % escape(author['email']) - yield u' </author>\n' + yield u" <author>\n" + yield u" <name>%s</name>\n" % escape(author["name"]) + if "uri" in author: + yield u" <uri>%s</uri>\n" % escape(author["uri"]) + if "email" in author: + yield u" <email>%s</email>\n" % escape(author["email"]) + yield u" </author>\n" for link in self.links: - yield u' <link %s/>\n' % ''.join('%s="%s" ' % - (k, escape(link[k])) for k in link) + yield u" <link %s/>\n" % "".join( + '%s="%s" ' % (k, escape(link[k])) for k in link + ) for category in self.categories: - yield u' <category %s/>\n' % ''.join('%s="%s" ' % - (k, escape(category[k])) for k in category) + yield u" <category %s/>\n" % "".join( + '%s="%s" ' % (k, escape(category[k])) for k in category + ) if self.summary: - yield u' ' + _make_text_block('summary', self.summary, - self.summary_type) + yield u" " + _make_text_block("summary", self.summary, self.summary_type) if self.content: - yield u' ' + _make_text_block('content', self.content, - self.content_type) - yield u'</entry>\n' + yield u" " + _make_text_block("content", self.content, self.content_type) + yield u"</entry>\n" def to_string(self): """Convert the feed item into a unicode object.""" - return u''.join(self.generate()) + return u"".join(self.generate()) def __str__(self): return self.to_string() diff --git a/src/werkzeug/contrib/cache.py b/src/werkzeug/contrib/cache.py index 8f5a33ae..79c749b5 100644 --- a/src/werkzeug/contrib/cache.py +++ b/src/werkzeug/contrib/cache.py @@ -56,23 +56,27 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import errno import os +import platform import re -import errno import tempfile -import platform import warnings from hashlib import md5 from time import time + +from .._compat import integer_types +from .._compat import iteritems +from .._compat import string_types +from .._compat import text_type +from .._compat import to_native +from ..posixemulation import rename + try: import cPickle as pickle except ImportError: # pragma: no cover import pickle -from werkzeug._compat import iteritems, string_types, text_type, \ - integer_types, to_native -from werkzeug.posixemulation import rename - warnings.warn( "'werkzeug.contrib.cache' is deprecated as of version 0.15 and will" " be removed in version 1.0. It has moved to https://github.com" @@ -93,13 +97,12 @@ def _items(mappingorseq): ... assert k*k == v """ - if hasattr(mappingorseq, 'items'): + if hasattr(mappingorseq, "items"): return iteritems(mappingorseq) return mappingorseq class BaseCache(object): - """Baseclass for the cache systems. All the cache systems implement this API or a superset of it. @@ -224,10 +227,10 @@ class BaseCache(object): :param key: the key to check """ raise NotImplementedError( - '%s doesn\'t have an efficient implementation of `has`. That ' - 'means it is impossible to check whether a key exists without ' - 'fully loading the key\'s data. Consider using `self.get` ' - 'explicitly if you don\'t care about performance.' + "%s doesn't have an efficient implementation of `has`. That " + "means it is impossible to check whether a key exists without " + "fully loading the key's data. Consider using `self.get` " + "explicitly if you don't care about performance." ) def clear(self): @@ -267,7 +270,6 @@ class BaseCache(object): class NullCache(BaseCache): - """A cache that doesn't cache. This can be useful for unit testing. :param default_timeout: a dummy parameter that is ignored but exists @@ -279,7 +281,6 @@ class NullCache(BaseCache): class SimpleCache(BaseCache): - """Simple memory cache for single process environments. This class exists mainly for the development server and is not 100% thread safe. It tries to use as many atomic operations as possible and no locks for simplicity @@ -325,15 +326,13 @@ class SimpleCache(BaseCache): def set(self, key, value, timeout=None): expires = self._normalize_timeout(timeout) self._prune() - self._cache[key] = (expires, pickle.dumps(value, - pickle.HIGHEST_PROTOCOL)) + self._cache[key] = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) return True def add(self, key, value, timeout=None): expires = self._normalize_timeout(timeout) self._prune() - item = (expires, pickle.dumps(value, - pickle.HIGHEST_PROTOCOL)) + item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) if key in self._cache: return False self._cache.setdefault(key, item) @@ -349,11 +348,11 @@ class SimpleCache(BaseCache): except KeyError: return False -_test_memcached_key = re.compile(r'[^\x00-\x21\xff]{1,250}$').match +_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match -class MemcachedCache(BaseCache): +class MemcachedCache(BaseCache): """A cache that uses memcached as backend. The first argument can either be an object that resembles the API of a @@ -392,10 +391,10 @@ class MemcachedCache(BaseCache): BaseCache.__init__(self, default_timeout) if servers is None or isinstance(servers, (list, tuple)): if servers is None: - servers = ['127.0.0.1:11211'] + servers = ["127.0.0.1:11211"] self._client = self.import_preferred_memcache_lib(servers) if self._client is None: - raise RuntimeError('no memcache module found') + raise RuntimeError("no memcache module found") else: # NOTE: servers is actually an already initialized memcache # client. @@ -404,7 +403,7 @@ class MemcachedCache(BaseCache): self.key_prefix = to_native(key_prefix) def _normalize_key(self, key): - key = to_native(key, 'utf-8') + key = to_native(key, "utf-8") if self.key_prefix: key = self.key_prefix + key return key @@ -484,7 +483,7 @@ class MemcachedCache(BaseCache): def has(self, key): key = self._normalize_key(key) if _test_memcached_key(key): - return self._client.append(key, '') + return self._client.append(key, "") return False def clear(self): @@ -534,7 +533,6 @@ GAEMemcachedCache = MemcachedCache class RedisCache(BaseCache): - """Uses the Redis key-value store as a cache backend. The first argument can be either a string denoting address of the Redis @@ -570,24 +568,32 @@ class RedisCache(BaseCache): Any additional keyword arguments will be passed to ``redis.Redis``. """ - def __init__(self, host='localhost', port=6379, password=None, - db=0, default_timeout=300, key_prefix=None, **kwargs): + def __init__( + self, + host="localhost", + port=6379, + password=None, + db=0, + default_timeout=300, + key_prefix=None, + **kwargs + ): BaseCache.__init__(self, default_timeout) if host is None: - raise ValueError('RedisCache host parameter may not be None') + raise ValueError("RedisCache host parameter may not be None") if isinstance(host, string_types): try: import redis except ImportError: - raise RuntimeError('no redis module found') - if kwargs.get('decode_responses', None): - raise ValueError('decode_responses is not supported by ' - 'RedisCache.') - self._client = redis.Redis(host=host, port=port, password=password, - db=db, **kwargs) + raise RuntimeError("no redis module found") + if kwargs.get("decode_responses", None): + raise ValueError("decode_responses is not supported by RedisCache.") + self._client = redis.Redis( + host=host, port=port, password=password, db=db, **kwargs + ) else: self._client = host - self.key_prefix = key_prefix or '' + self.key_prefix = key_prefix or "" def _normalize_timeout(self, timeout): timeout = BaseCache._normalize_timeout(self, timeout) @@ -601,8 +607,8 @@ class RedisCache(BaseCache): """ t = type(value) if t in integer_types: - return str(value).encode('ascii') - return b'!' + pickle.dumps(value) + return str(value).encode("ascii") + return b"!" + pickle.dumps(value) def load_object(self, value): """The reversal of :meth:`dump_object`. This might be called with @@ -610,7 +616,7 @@ class RedisCache(BaseCache): """ if value is None: return None - if value.startswith(b'!'): + if value.startswith(b"!"): try: return pickle.loads(value[1:]) except pickle.PickleError: @@ -633,20 +639,19 @@ class RedisCache(BaseCache): timeout = self._normalize_timeout(timeout) dump = self.dump_object(value) if timeout == -1: - result = self._client.set(name=self.key_prefix + key, - value=dump) + result = self._client.set(name=self.key_prefix + key, value=dump) else: - result = self._client.setex(name=self.key_prefix + key, - value=dump, time=timeout) + result = self._client.setex( + name=self.key_prefix + key, value=dump, time=timeout + ) return result def add(self, key, value, timeout=None): timeout = self._normalize_timeout(timeout) dump = self.dump_object(value) - return ( - self._client.setnx(name=self.key_prefix + key, value=dump) - and self._client.expire(name=self.key_prefix + key, time=timeout) - ) + return self._client.setnx( + name=self.key_prefix + key, value=dump + ) and self._client.expire(name=self.key_prefix + key, time=timeout) def set_many(self, mapping, timeout=None): timeout = self._normalize_timeout(timeout) @@ -659,8 +664,7 @@ class RedisCache(BaseCache): if timeout == -1: pipe.set(name=self.key_prefix + key, value=dump) else: - pipe.setex(name=self.key_prefix + key, value=dump, - time=timeout) + pipe.setex(name=self.key_prefix + key, value=dump, time=timeout) return pipe.execute() def delete(self, key): @@ -679,7 +683,7 @@ class RedisCache(BaseCache): def clear(self): status = False if self.key_prefix: - keys = self._client.keys(self.key_prefix + '*') + keys = self._client.keys(self.key_prefix + "*") if keys: status = self._client.delete(*keys) else: @@ -694,7 +698,6 @@ class RedisCache(BaseCache): class FileSystemCache(BaseCache): - """A cache that stores the items on the file system. This cache depends on being the only user of the `cache_dir`. Make absolutely sure that nobody but this cache stores files there or otherwise the cache will @@ -711,12 +714,11 @@ class FileSystemCache(BaseCache): """ #: used for temporary files by the FileSystemCache - _fs_transaction_suffix = '.__wz_cache' + _fs_transaction_suffix = ".__wz_cache" #: keep amount of files in a cache element - _fs_count_file = '__wz_cache_count' + _fs_count_file = "__wz_cache_count" - def __init__(self, cache_dir, threshold=500, default_timeout=300, - mode=0o600): + def __init__(self, cache_dir, threshold=500, default_timeout=300, mode=0o600): BaseCache.__init__(self, default_timeout) self._path = cache_dir self._threshold = threshold @@ -754,11 +756,14 @@ class FileSystemCache(BaseCache): def _list_dir(self): """return a list of (fully qualified) cache filenames """ - mgmt_files = [self._get_filename(name).split('/')[-1] - for name in (self._fs_count_file,)] - return [os.path.join(self._path, fn) for fn in os.listdir(self._path) - if not fn.endswith(self._fs_transaction_suffix) - and fn not in mgmt_files] + mgmt_files = [ + self._get_filename(name).split("/")[-1] for name in (self._fs_count_file,) + ] + return [ + os.path.join(self._path, fn) + for fn in os.listdir(self._path) + if not fn.endswith(self._fs_transaction_suffix) and fn not in mgmt_files + ] def _prune(self): if self._threshold == 0 or not self._file_count > self._threshold: @@ -769,7 +774,7 @@ class FileSystemCache(BaseCache): for idx, fname in enumerate(entries): try: remove = False - with open(fname, 'rb') as f: + with open(fname, "rb") as f: expires = pickle.load(f) remove = (expires != 0 and expires <= now) or idx % 3 == 0 @@ -791,14 +796,14 @@ class FileSystemCache(BaseCache): def _get_filename(self, key): if isinstance(key, text_type): - key = key.encode('utf-8') # XXX unicode review + key = key.encode("utf-8") # XXX unicode review hash = md5(key).hexdigest() return os.path.join(self._path, hash) def get(self, key): filename = self._get_filename(key) try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: pickle_time = pickle.load(f) if pickle_time == 0 or pickle_time >= time(): return pickle.load(f) @@ -826,9 +831,10 @@ class FileSystemCache(BaseCache): timeout = self._normalize_timeout(timeout) filename = self._get_filename(key) try: - fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix, - dir=self._path) - with os.fdopen(fd, 'wb') as f: + fd, tmp = tempfile.mkstemp( + suffix=self._fs_transaction_suffix, dir=self._path + ) + with os.fdopen(fd, "wb") as f: pickle.dump(timeout, f, 1) pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) rename(tmp, filename) @@ -855,7 +861,7 @@ class FileSystemCache(BaseCache): def has(self, key): filename = self._get_filename(key) try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: pickle_time = pickle.load(f) if pickle_time == 0 or pickle_time >= time(): return True @@ -867,7 +873,7 @@ class FileSystemCache(BaseCache): class UWSGICache(BaseCache): - """ Implements the cache using uWSGI's caching framework. + """Implements the cache using uWSGI's caching framework. .. note:: This class cannot be used when running under PyPy, because the uWSGI @@ -880,19 +886,24 @@ class UWSGICache(BaseCache): same instance as the werkzeug app, you only have to provide the name of the cache. """ - def __init__(self, default_timeout=300, cache=''): + + def __init__(self, default_timeout=300, cache=""): BaseCache.__init__(self, default_timeout) - if platform.python_implementation() == 'PyPy': - raise RuntimeError("uWSGI caching does not work under PyPy, see " - "the docs for more details.") + if platform.python_implementation() == "PyPy": + raise RuntimeError( + "uWSGI caching does not work under PyPy, see " + "the docs for more details." + ) try: import uwsgi + self._uwsgi = uwsgi except ImportError: - raise RuntimeError("uWSGI could not be imported, are you " - "running under uWSGI?") + raise RuntimeError( + "uWSGI could not be imported, are you running under uWSGI?" + ) self.cache = cache @@ -906,14 +917,14 @@ class UWSGICache(BaseCache): return self._uwsgi.cache_del(key, self.cache) def set(self, key, value, timeout=None): - return self._uwsgi.cache_update(key, pickle.dumps(value), - self._normalize_timeout(timeout), - self.cache) + return self._uwsgi.cache_update( + key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + ) def add(self, key, value, timeout=None): - return self._uwsgi.cache_set(key, pickle.dumps(value), - self._normalize_timeout(timeout), - self.cache) + return self._uwsgi.cache_set( + key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + ) def clear(self): return self._uwsgi.cache_clear(self.cache) diff --git a/src/werkzeug/contrib/fixers.py b/src/werkzeug/contrib/fixers.py index ad386f57..8df0afda 100644 --- a/src/werkzeug/contrib/fixers.py +++ b/src/werkzeug/contrib/fixers.py @@ -28,16 +28,18 @@ This module includes various helpers that fix web server behavior. """ import warnings +from ..datastructures import Headers +from ..datastructures import ResponseCacheControl +from ..http import parse_cache_control_header +from ..http import parse_options_header +from ..http import parse_set_header +from ..middleware.proxy_fix import ProxyFix as _ProxyFix +from ..useragents import UserAgent + try: - from urllib import unquote -except ImportError: from urllib.parse import unquote - -from werkzeug.http import parse_options_header, parse_cache_control_header, \ - parse_set_header -from werkzeug.useragents import UserAgent -from werkzeug.datastructures import Headers, ResponseCacheControl -from werkzeug.middleware.proxy_fix import ProxyFix as _ProxyFix +except ImportError: + from urllib import unquote class CGIRootFix(object): @@ -57,7 +59,7 @@ class CGIRootFix(object): ``LighttpdCGIRootFix``. """ - def __init__(self, app, app_root='/'): + def __init__(self, app, app_root="/"): warnings.warn( "'CGIRootFix' is deprecated as of version 0.15 and will be" " removed in version 1.0.", @@ -68,7 +70,7 @@ class CGIRootFix(object): self.app_root = app_root.strip("/") def __call__(self, environ, start_response): - environ['SCRIPT_NAME'] = self.app_root + environ["SCRIPT_NAME"] = self.app_root return self.app(environ, start_response) @@ -84,7 +86,6 @@ class LighttpdCGIRootFix(CGIRootFix): class PathInfoFromRequestUriFix(object): - """On windows environment variables are limited to the system charset which makes it impossible to store the `PATH_INFO` variable in the environment without loss of information on some systems. @@ -112,14 +113,13 @@ class PathInfoFromRequestUriFix(object): self.app = app def __call__(self, environ, start_response): - for key in 'REQUEST_URL', 'REQUEST_URI', 'UNENCODED_URL': + for key in "REQUEST_URL", "REQUEST_URI", "UNENCODED_URL": if key not in environ: continue request_uri = unquote(environ[key]) - script_name = unquote(environ.get('SCRIPT_NAME', '')) + script_name = unquote(environ.get("SCRIPT_NAME", "")) if request_uri.startswith(script_name): - environ['PATH_INFO'] = request_uri[len(script_name):] \ - .split('?', 1)[0] + environ["PATH_INFO"] = request_uri[len(script_name) :].split("?", 1)[0] break return self.app(environ, start_response) @@ -144,7 +144,6 @@ class ProxyFix(_ProxyFix): class HeaderRewriterFix(object): - """This middleware can remove response headers and add others. This is for example useful to remove the `Date` header from responses if you are using a server that adds that header, no matter if it's present or @@ -182,11 +181,11 @@ class HeaderRewriterFix(object): new_headers.append((key, value)) new_headers += self.add_headers return start_response(status, new_headers, exc_info) + return self.app(environ, rewriting_start_response) class InternetExplorerFix(object): - """This middleware fixes a couple of bugs with Microsoft Internet Explorer. Currently the following fixes are applied: @@ -224,40 +223,40 @@ class InternetExplorerFix(object): def fix_headers(self, environ, headers, status=None): if self.fix_vary: - header = headers.get('content-type', '') + header = headers.get("content-type", "") mimetype, options = parse_options_header(header) - if mimetype not in ('text/html', 'text/plain', 'text/sgml'): - headers.pop('vary', None) + if mimetype not in ("text/html", "text/plain", "text/sgml"): + headers.pop("vary", None) - if self.fix_attach and 'content-disposition' in headers: - pragma = parse_set_header(headers.get('pragma', '')) - pragma.discard('no-cache') + if self.fix_attach and "content-disposition" in headers: + pragma = parse_set_header(headers.get("pragma", "")) + pragma.discard("no-cache") header = pragma.to_header() if not header: - headers.pop('pragma', '') + headers.pop("pragma", "") else: - headers['Pragma'] = header - header = headers.get('cache-control', '') + headers["Pragma"] = header + header = headers.get("cache-control", "") if header: - cc = parse_cache_control_header(header, - cls=ResponseCacheControl) + cc = parse_cache_control_header(header, cls=ResponseCacheControl) cc.no_cache = None cc.no_store = False header = cc.to_header() if not header: - headers.pop('cache-control', '') + headers.pop("cache-control", "") else: - headers['Cache-Control'] = header + headers["Cache-Control"] = header def run_fixed(self, environ, start_response): def fixing_start_response(status, headers, exc_info=None): headers = Headers(headers) self.fix_headers(environ, headers, status) return start_response(status, headers.to_wsgi_list(), exc_info) + return self.app(environ, fixing_start_response) def __call__(self, environ, start_response): ua = UserAgent(environ) - if ua.browser != 'msie': + if ua.browser != "msie": return self.app(environ, start_response) return self.run_fixed(environ, start_response) diff --git a/src/werkzeug/contrib/iterio.py b/src/werkzeug/contrib/iterio.py index 41e1e20e..b6724540 100644 --- a/src/werkzeug/contrib/iterio.py +++ b/src/werkzeug/contrib/iterio.py @@ -41,13 +41,13 @@ r""" """ import warnings +from .._compat import implements_iterator + try: import greenlet except ImportError: greenlet = None -from werkzeug._compat import implements_iterator - warnings.warn( "'werkzeug.contrib.iterio' is deprecated as of version 0.15 and" " will be removed in version 1.0.", @@ -61,19 +61,18 @@ def _mixed_join(iterable, sentinel): iterator = iter(iterable) first_item = next(iterator, sentinel) if isinstance(first_item, bytes): - return first_item + b''.join(iterator) - return first_item + u''.join(iterator) + return first_item + b"".join(iterator) + return first_item + u"".join(iterator) def _newline(reference_string): if isinstance(reference_string, bytes): - return b'\n' - return u'\n' + return b"\n" + return u"\n" @implements_iterator class IterIO(object): - """Instances of this object implement an interface compatible with the standard Python :class:`file` object. Streams are either read-only or write-only depending on how the object is created. @@ -100,7 +99,7 @@ class IterIO(object): `sentinel` parameter was added. """ - def __new__(cls, obj, sentinel=''): + def __new__(cls, obj, sentinel=""): try: iterator = iter(obj) except TypeError: @@ -112,53 +111,53 @@ class IterIO(object): def tell(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return self.pos def isatty(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return False def seek(self, pos, mode=0): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def truncate(self, size=None): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def write(self, s): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def writelines(self, list): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def read(self, n=-1): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def readlines(self, sizehint=0): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def readline(self, length=None): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') - raise IOError(9, 'Bad file descriptor') + raise ValueError("I/O operation on closed file") + raise IOError(9, "Bad file descriptor") def __next__(self): if self.closed: @@ -170,12 +169,11 @@ class IterIO(object): class IterI(IterIO): - """Convert an stream into an iterator.""" - def __new__(cls, func, sentinel=''): + def __new__(cls, func, sentinel=""): if greenlet is None: - raise RuntimeError('IterI requires greenlet support') + raise RuntimeError("IterI requires greenlet support") stream = object.__new__(cls) stream._parent = greenlet.getcurrent() stream._buffer = [] @@ -201,7 +199,7 @@ class IterI(IterIO): def write(self, s): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if s: self.pos += len(s) self._buffer.append(s) @@ -212,7 +210,7 @@ class IterI(IterIO): def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") self._flush_impl() def _flush_impl(self): @@ -225,10 +223,9 @@ class IterI(IterIO): class IterO(IterIO): - """Iter output. Wrap an iterator and give it a stream like interface.""" - def __new__(cls, gen, sentinel=''): + def __new__(cls, gen, sentinel=""): self = object.__new__(cls) self._gen = gen self._buf = None @@ -241,8 +238,8 @@ class IterO(IterIO): return self def _buf_append(self, string): - '''Replace string directly without appending to an empty string, - avoiding type issues.''' + """Replace string directly without appending to an empty string, + avoiding type issues.""" if not self._buf: self._buf = string else: @@ -251,12 +248,12 @@ class IterO(IterIO): def close(self): if not self.closed: self.closed = True - if hasattr(self._gen, 'close'): + if hasattr(self._gen, "close"): self._gen.close() def seek(self, pos, mode=0): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if mode == 1: pos += self.pos elif mode == 2: @@ -264,10 +261,10 @@ class IterO(IterIO): self.pos = min(self.pos, self.pos + pos) return elif mode != 0: - raise IOError('Invalid argument') + raise IOError("Invalid argument") buf = [] try: - tmp_end_pos = len(self._buf or '') + tmp_end_pos = len(self._buf or "") while pos > tmp_end_pos: item = next(self._gen) tmp_end_pos += len(item) @@ -280,10 +277,10 @@ class IterO(IterIO): def read(self, n=-1): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") if n < 0: self._buf_append(_mixed_join(self._gen, self.sentinel)) - result = self._buf[self.pos:] + result = self._buf[self.pos :] self.pos += len(result) return result new_pos = self.pos + n @@ -304,13 +301,13 @@ class IterO(IterIO): new_pos = max(0, new_pos) try: - return self._buf[self.pos:new_pos] + return self._buf[self.pos : new_pos] finally: self.pos = min(new_pos, len(self._buf)) def readline(self, length=None): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") nl_pos = -1 if self._buf: @@ -344,7 +341,7 @@ class IterO(IterIO): if length is not None and self.pos + length < new_pos: new_pos = self.pos + length try: - return self._buf[self.pos:new_pos] + return self._buf[self.pos : new_pos] finally: self.pos = min(new_pos, len(self._buf)) diff --git a/src/werkzeug/contrib/securecookie.py b/src/werkzeug/contrib/securecookie.py index b9ba20ed..c4c9eee2 100644 --- a/src/werkzeug/contrib/securecookie.py +++ b/src/werkzeug/contrib/securecookie.py @@ -88,19 +88,22 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pickle import base64 +import pickle import warnings +from hashlib import sha1 as _default_hash from hmac import new as hmac from time import time -from hashlib import sha1 as _default_hash -from werkzeug._compat import iteritems, text_type, to_bytes -from werkzeug.urls import url_quote_plus, url_unquote_plus -from werkzeug._internal import _date_to_unix -from werkzeug.contrib.sessions import ModificationTrackingDict -from werkzeug.security import safe_str_cmp -from werkzeug._compat import to_native +from .._compat import iteritems +from .._compat import text_type +from .._compat import to_bytes +from .._compat import to_native +from .._internal import _date_to_unix +from ..contrib.sessions import ModificationTrackingDict +from ..security import safe_str_cmp +from ..urls import url_quote_plus +from ..urls import url_unquote_plus warnings.warn( "'werkzeug.contrib.securecookie' is deprecated as of version 0.15" @@ -112,12 +115,10 @@ warnings.warn( class UnquoteError(Exception): - """Internal exception used to signal failures on quoting.""" class SecureCookie(ModificationTrackingDict): - """Represents a secure cookie. You can subclass this class and provide an alternative mac method. The import thing is that the mac method is a function with a similar interface to the hashlib. Required @@ -164,7 +165,7 @@ class SecureCookie(ModificationTrackingDict): # explicitly convert it into a bytestring because python 2.6 # no longer performs an implicit string conversion on hmac if secret_key is not None: - secret_key = to_bytes(secret_key, 'utf-8') + secret_key = to_bytes(secret_key, "utf-8") self.secret_key = secret_key self.new = new @@ -178,10 +179,10 @@ class SecureCookie(ModificationTrackingDict): ) def __repr__(self): - return '<%s %s%s>' % ( + return "<%s %s%s>" % ( self.__class__.__name__, dict.__repr__(self), - '*' if self.should_save else '' + "*" if self.should_save else "", ) @property @@ -201,7 +202,7 @@ class SecureCookie(ModificationTrackingDict): if cls.serialization_method is not None: value = cls.serialization_method.dumps(value) if cls.quote_base64: - value = b''.join( + value = b"".join( base64.b64encode(to_bytes(value, "utf8")).splitlines() ).strip() return value @@ -236,21 +237,19 @@ class SecureCookie(ModificationTrackingDict): :class:`datetime.datetime` object) """ if self.secret_key is None: - raise RuntimeError('no secret key defined') + raise RuntimeError("no secret key defined") if expires: - self['_expires'] = _date_to_unix(expires) + self["_expires"] = _date_to_unix(expires) result = [] mac = hmac(self.secret_key, None, self.hash_method) for key, value in sorted(self.items()): - result.append(('%s=%s' % ( - url_quote_plus(key), - self.quote(value).decode('ascii') - )).encode('ascii')) - mac.update(b'|' + result[-1]) - return b'?'.join([ - base64.b64encode(mac.digest()).strip(), - b'&'.join(result) - ]) + result.append( + ( + "%s=%s" % (url_quote_plus(key), self.quote(value).decode("ascii")) + ).encode("ascii") + ) + mac.update(b"|" + result[-1]) + return b"?".join([base64.b64encode(mac.digest()).strip(), b"&".join(result)]) @classmethod def unserialize(cls, string, secret_key): @@ -261,24 +260,24 @@ class SecureCookie(ModificationTrackingDict): :return: a new :class:`SecureCookie`. """ if isinstance(string, text_type): - string = string.encode('utf-8', 'replace') + string = string.encode("utf-8", "replace") if isinstance(secret_key, text_type): - secret_key = secret_key.encode('utf-8', 'replace') + secret_key = secret_key.encode("utf-8", "replace") try: - base64_hash, data = string.split(b'?', 1) + base64_hash, data = string.split(b"?", 1) except (ValueError, IndexError): items = () else: items = {} mac = hmac(secret_key, None, cls.hash_method) - for item in data.split(b'&'): - mac.update(b'|' + item) - if b'=' not in item: + for item in data.split(b"&"): + mac.update(b"|" + item) + if b"=" not in item: items = None break - key, value = item.split(b'=', 1) + key, value = item.split(b"=", 1) # try to make the key a string - key = url_unquote_plus(key.decode('ascii')) + key = url_unquote_plus(key.decode("ascii")) try: key = to_native(key) except UnicodeError: @@ -298,17 +297,17 @@ class SecureCookie(ModificationTrackingDict): except UnquoteError: items = () else: - if '_expires' in items: - if time() > items['_expires']: + if "_expires" in items: + if time() > items["_expires"]: items = () else: - del items['_expires'] + del items["_expires"] else: items = () return cls(items, secret_key, False) @classmethod - def load_cookie(cls, request, key='session', secret_key=None): + def load_cookie(cls, request, key="session", secret_key=None): """Loads a :class:`SecureCookie` from a cookie in request. If the cookie is not set, a new :class:`SecureCookie` instanced is returned. @@ -325,9 +324,19 @@ class SecureCookie(ModificationTrackingDict): return cls(secret_key=secret_key) return cls.unserialize(data, secret_key) - def save_cookie(self, response, key='session', expires=None, - session_expires=None, max_age=None, path='/', domain=None, - secure=None, httponly=False, force=False): + def save_cookie( + self, + response, + key="session", + expires=None, + session_expires=None, + max_age=None, + path="/", + domain=None, + secure=None, + httponly=False, + force=False, + ): """Saves the SecureCookie in a cookie on response object. All parameters that are not described here are forwarded directly to :meth:`~BaseResponse.set_cookie`. @@ -341,6 +350,13 @@ class SecureCookie(ModificationTrackingDict): """ if force or self.should_save: data = self.serialize(session_expires or expires) - response.set_cookie(key, data, expires=expires, max_age=max_age, - path=path, domain=domain, secure=secure, - httponly=httponly) + response.set_cookie( + key, + data, + expires=expires, + max_age=max_age, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + ) diff --git a/src/werkzeug/contrib/sessions.py b/src/werkzeug/contrib/sessions.py index 302fd11c..866e827c 100644 --- a/src/werkzeug/contrib/sessions.py +++ b/src/werkzeug/contrib/sessions.py @@ -51,22 +51,26 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import os +import re import tempfile import warnings +from hashlib import sha1 from os import path -from time import time +from pickle import dump +from pickle import HIGHEST_PROTOCOL +from pickle import load from random import random -from hashlib import sha1 -from pickle import dump, load, HIGHEST_PROTOCOL +from time import time -from werkzeug.datastructures import CallbackDict -from werkzeug.utils import dump_cookie, parse_cookie -from werkzeug.wsgi import ClosingIterator -from werkzeug.posixemulation import rename -from werkzeug._compat import PY2, text_type -from werkzeug.filesystem import get_filesystem_encoding +from .._compat import PY2 +from .._compat import text_type +from ..datastructures import CallbackDict +from ..filesystem import get_filesystem_encoding +from ..posixemulation import rename +from ..utils import dump_cookie +from ..utils import parse_cookie +from ..wsgi import ClosingIterator warnings.warn( "'werkzeug.contrib.sessions' is deprecated as of version 0.15 and" @@ -76,31 +80,28 @@ warnings.warn( stacklevel=2, ) -_sha1_re = re.compile(r'^[a-f0-9]{40}$') +_sha1_re = re.compile(r"^[a-f0-9]{40}$") def _urandom(): - if hasattr(os, 'urandom'): + if hasattr(os, "urandom"): return os.urandom(30) - return text_type(random()).encode('ascii') + return text_type(random()).encode("ascii") def generate_key(salt=None): if salt is None: - salt = repr(salt).encode('ascii') - return sha1(b''.join([ - salt, - str(time()).encode('ascii'), - _urandom() - ])).hexdigest() + salt = repr(salt).encode("ascii") + return sha1(b"".join([salt, str(time()).encode("ascii"), _urandom()])).hexdigest() class ModificationTrackingDict(CallbackDict): - __slots__ = ('modified',) + __slots__ = ("modified",) def __init__(self, *args, **kwargs): def on_update(self): self.modified = True + self.modified = False CallbackDict.__init__(self, on_update=on_update) dict.update(self, *args, **kwargs) @@ -120,12 +121,12 @@ class ModificationTrackingDict(CallbackDict): class Session(ModificationTrackingDict): - """Subclass of a dict that keeps track of direct object changes. Changes in mutable structures are not tracked, for those you have to set `modified` to `True` by hand. """ - __slots__ = ModificationTrackingDict.__slots__ + ('sid', 'new') + + __slots__ = ModificationTrackingDict.__slots__ + ("sid", "new") def __init__(self, data, sid, new=False): ModificationTrackingDict.__init__(self, data) @@ -133,10 +134,10 @@ class Session(ModificationTrackingDict): self.new = new def __repr__(self): - return '<%s %s%s>' % ( + return "<%s %s%s>" % ( self.__class__.__name__, dict.__repr__(self), - '*' if self.should_save else '' + "*" if self.should_save else "", ) @property @@ -151,7 +152,6 @@ class Session(ModificationTrackingDict): class SessionStore(object): - """Baseclass for all session stores. The Werkzeug contrib module does not implement any useful stores besides the filesystem store, application developers are encouraged to create their own stores. @@ -197,11 +197,10 @@ class SessionStore(object): #: used for temporary files by the filesystem session store -_fs_transaction_suffix = '.__wz_sess' +_fs_transaction_suffix = ".__wz_sess" class FilesystemSessionStore(SessionStore): - """Simple example session store that saves sessions on the filesystem. This store works best on POSIX systems and Windows Vista / Windows Server 2008 and newer. @@ -223,17 +222,23 @@ class FilesystemSessionStore(SessionStore): not yet saved. """ - def __init__(self, path=None, filename_template='werkzeug_%s.sess', - session_class=None, renew_missing=False, mode=0o644): + def __init__( + self, + path=None, + filename_template="werkzeug_%s.sess", + session_class=None, + renew_missing=False, + mode=0o644, + ): SessionStore.__init__(self, session_class) if path is None: path = tempfile.gettempdir() self.path = path if isinstance(filename_template, text_type) and PY2: - filename_template = filename_template.encode( - get_filesystem_encoding()) - assert not filename_template.endswith(_fs_transaction_suffix), \ - 'filename templates may not end with %s' % _fs_transaction_suffix + filename_template = filename_template.encode(get_filesystem_encoding()) + assert not filename_template.endswith(_fs_transaction_suffix), ( + "filename templates may not end with %s" % _fs_transaction_suffix + ) self.filename_template = filename_template self.renew_missing = renew_missing self.mode = mode @@ -248,9 +253,8 @@ class FilesystemSessionStore(SessionStore): def save(self, session): fn = self.get_session_filename(session.sid) - fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix, - dir=self.path) - f = os.fdopen(fd, 'wb') + fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix, dir=self.path) + f = os.fdopen(fd, "wb") try: dump(dict(session), f, HIGHEST_PROTOCOL) finally: @@ -272,7 +276,7 @@ class FilesystemSessionStore(SessionStore): if not self.is_valid_key(sid): return self.new() try: - f = open(self.get_session_filename(sid), 'rb') + f = open(self.get_session_filename(sid), "rb") except IOError: if self.renew_missing: return self.new() @@ -292,9 +296,10 @@ class FilesystemSessionStore(SessionStore): .. versionadded:: 0.6 """ - before, after = self.filename_template.split('%s', 1) - filename_re = re.compile(r'%s(.{5,})%s$' % (re.escape(before), - re.escape(after))) + before, after = self.filename_template.split("%s", 1) + filename_re = re.compile( + r"%s(.{5,})%s$" % (re.escape(before), re.escape(after)) + ) result = [] for filename in os.listdir(self.path): #: this is a session that is still being saved. @@ -307,7 +312,6 @@ class FilesystemSessionStore(SessionStore): class SessionMiddleware(object): - """A simple middleware that puts the session object of a store provided into the WSGI environ. It automatically sets cookies and restores sessions. @@ -323,11 +327,20 @@ class SessionMiddleware(object): compatibility. """ - def __init__(self, app, store, cookie_name='session_id', - cookie_age=None, cookie_expires=None, cookie_path='/', - cookie_domain=None, cookie_secure=None, - cookie_httponly=False, cookie_samesite='Lax', - environ_key='werkzeug.session'): + def __init__( + self, + app, + store, + cookie_name="session_id", + cookie_age=None, + cookie_expires=None, + cookie_path="/", + cookie_domain=None, + cookie_secure=None, + cookie_httponly=False, + cookie_samesite="Lax", + environ_key="werkzeug.session", + ): self.app = app self.store = store self.cookie_name = cookie_name @@ -341,7 +354,7 @@ class SessionMiddleware(object): self.environ_key = environ_key def __call__(self, environ, start_response): - cookie = parse_cookie(environ.get('HTTP_COOKIE', '')) + cookie = parse_cookie(environ.get("HTTP_COOKIE", "")) sid = cookie.get(self.cookie_name, None) if sid is None: session = self.store.new() @@ -352,12 +365,25 @@ class SessionMiddleware(object): def injecting_start_response(status, headers, exc_info=None): if session.should_save: self.store.save(session) - headers.append(('Set-Cookie', dump_cookie(self.cookie_name, - session.sid, self.cookie_age, - self.cookie_expires, self.cookie_path, - self.cookie_domain, self.cookie_secure, - self.cookie_httponly, - samesite=self.cookie_samesite))) + headers.append( + ( + "Set-Cookie", + dump_cookie( + self.cookie_name, + session.sid, + self.cookie_age, + self.cookie_expires, + self.cookie_path, + self.cookie_domain, + self.cookie_secure, + self.cookie_httponly, + samesite=self.cookie_samesite, + ), + ) + ) return start_response(status, headers, exc_info) - return ClosingIterator(self.app(environ, injecting_start_response), - lambda: self.store.save_if_modified(session)) + + return ClosingIterator( + self.app(environ, injecting_start_response), + lambda: self.store.save_if_modified(session), + ) diff --git a/src/werkzeug/contrib/wrappers.py b/src/werkzeug/contrib/wrappers.py index fa3bc565..49b82a71 100644 --- a/src/werkzeug/contrib/wrappers.py +++ b/src/werkzeug/contrib/wrappers.py @@ -23,11 +23,12 @@ import codecs import warnings -from werkzeug.exceptions import BadRequest -from werkzeug.utils import cached_property -from werkzeug.http import dump_options_header, parse_options_header -from werkzeug._compat import wsgi_decoding_dance -from werkzeug.wrappers.json import JSONMixin as _JSONMixin +from .._compat import wsgi_decoding_dance +from ..exceptions import BadRequest +from ..http import dump_options_header +from ..http import parse_options_header +from ..utils import cached_property +from ..wrappers.json import JSONMixin as _JSONMixin def is_known_charset(charset): @@ -87,8 +88,8 @@ class ProtobufRequestMixin(object): DeprecationWarning, stacklevel=2, ) - if 'protobuf' not in self.environ.get('CONTENT_TYPE', ''): - raise BadRequest('Not a Protobuf request') + if "protobuf" not in self.environ.get("CONTENT_TYPE", ""): + raise BadRequest("Not a Protobuf request") obj = proto_type() try: @@ -108,7 +109,8 @@ class RoutingArgsRequestMixin(object): """This request mixin adds support for the wsgiorg routing args `specification`_. - .. _specification: https://wsgi.readthedocs.io/en/latest/specifications/routing_args.html + .. _specification: https://wsgi.readthedocs.io/en/latest/ + specifications/routing_args.html .. deprecated:: 0.15 This mixin will be removed in version 1.0. @@ -122,7 +124,7 @@ class RoutingArgsRequestMixin(object): DeprecationWarning, stacklevel=2, ) - return self.environ.get('wsgiorg.routing_args', (()))[0] + return self.environ.get("wsgiorg.routing_args", (()))[0] def _set_routing_args(self, value): warnings.warn( @@ -133,13 +135,19 @@ class RoutingArgsRequestMixin(object): stacklevel=2, ) if self.shallow: - raise RuntimeError('A shallow request tried to modify the WSGI ' - 'environment. If you really want to do that, ' - 'set `shallow` to False.') - self.environ['wsgiorg.routing_args'] = (value, self.routing_vars) - - routing_args = property(_get_routing_args, _set_routing_args, doc=''' - The positional URL arguments as `tuple`.''') + raise RuntimeError( + "A shallow request tried to modify the WSGI " + "environment. If you really want to do that, " + "set `shallow` to False." + ) + self.environ["wsgiorg.routing_args"] = (value, self.routing_vars) + + routing_args = property( + _get_routing_args, + _set_routing_args, + doc=""" + The positional URL arguments as `tuple`.""", + ) del _get_routing_args, _set_routing_args def _get_routing_vars(self): @@ -150,7 +158,7 @@ class RoutingArgsRequestMixin(object): DeprecationWarning, stacklevel=2, ) - rv = self.environ.get('wsgiorg.routing_args') + rv = self.environ.get("wsgiorg.routing_args") if rv is not None: return rv[1] rv = {} @@ -167,13 +175,19 @@ class RoutingArgsRequestMixin(object): stacklevel=2, ) if self.shallow: - raise RuntimeError('A shallow request tried to modify the WSGI ' - 'environment. If you really want to do that, ' - 'set `shallow` to False.') - self.environ['wsgiorg.routing_args'] = (self.routing_args, value) - - routing_vars = property(_get_routing_vars, _set_routing_vars, doc=''' - The keyword URL arguments as `dict`.''') + raise RuntimeError( + "A shallow request tried to modify the WSGI " + "environment. If you really want to do that, " + "set `shallow` to False." + ) + self.environ["wsgiorg.routing_args"] = (self.routing_args, value) + + routing_vars = property( + _get_routing_vars, + _set_routing_vars, + doc=""" + The keyword URL arguments as `dict`.""", + ) del _get_routing_vars, _set_routing_vars @@ -216,9 +230,10 @@ class ReverseSlashBehaviorRequestMixin(object): DeprecationWarning, stacklevel=2, ) - path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '', - self.charset, self.encoding_errors) - return path.lstrip('/') + path = wsgi_decoding_dance( + self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors + ) + return path.lstrip("/") @cached_property def script_root(self): @@ -230,9 +245,10 @@ class ReverseSlashBehaviorRequestMixin(object): DeprecationWarning, stacklevel=2, ) - path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '', - self.charset, self.encoding_errors) - return path.rstrip('/') + '/' + path = wsgi_decoding_dance( + self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors + ) + return path.rstrip("/") + "/" class DynamicCharsetRequestMixin(object): @@ -268,7 +284,7 @@ class DynamicCharsetRequestMixin(object): #: is latin1 which is what HTTP specifies as default charset. #: You may however want to set this to utf-8 to better support #: browsers that do not transmit a charset for incoming data. - default_charset = 'latin1' + default_charset = "latin1" def unknown_charset(self, charset): """Called if a charset was provided but is not supported by @@ -279,7 +295,7 @@ class DynamicCharsetRequestMixin(object): :param charset: the charset that was not found. :return: the replacement charset. """ - return 'latin1' + return "latin1" @cached_property def charset(self): @@ -291,10 +307,10 @@ class DynamicCharsetRequestMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.environ.get('CONTENT_TYPE') + header = self.environ.get("CONTENT_TYPE") if header: ct, options = parse_options_header(header) - charset = options.get('charset') + charset = options.get("charset") if charset: if is_known_charset(charset): return charset @@ -327,7 +343,7 @@ class DynamicCharsetResponseMixin(object): """ #: the default charset. - default_charset = 'utf-8' + default_charset = "utf-8" def _get_charset(self): warnings.warn( @@ -337,9 +353,9 @@ class DynamicCharsetResponseMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.headers.get('content-type') + header = self.headers.get("content-type") if header: - charset = parse_options_header(header)[1].get('charset') + charset = parse_options_header(header)[1].get("charset") if charset: return charset return self.default_charset @@ -352,15 +368,18 @@ class DynamicCharsetResponseMixin(object): DeprecationWarning, stacklevel=2, ) - header = self.headers.get('content-type') + header = self.headers.get("content-type") ct, options = parse_options_header(header) if not ct: - raise TypeError('Cannot set charset if Content-Type ' - 'header is missing.') - options['charset'] = charset - self.headers['Content-Type'] = dump_options_header(ct, options) - - charset = property(_get_charset, _set_charset, doc=""" + raise TypeError("Cannot set charset if Content-Type header is missing.") + options["charset"] = charset + self.headers["Content-Type"] = dump_options_header(ct, options) + + charset = property( + _get_charset, + _set_charset, + doc=""" The charset for the response. It's stored inside the - Content-Type header as a parameter.""") + Content-Type header as a parameter.""", + ) del _get_charset, _set_charset diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py index a39397a2..9643db96 100644 --- a/src/werkzeug/datastructures.py +++ b/src/werkzeug/datastructures.py @@ -8,24 +8,32 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import codecs import mimetypes +import re from copy import deepcopy from itertools import repeat -from werkzeug._internal import _missing -from werkzeug._compat import BytesIO, collections_abc, iterkeys, itervalues, \ - iteritems, iterlists, PY2, text_type, integer_types, string_types, \ - make_literal_wrapper, to_native -from werkzeug.filesystem import get_filesystem_encoding - +from ._compat import BytesIO +from ._compat import collections_abc +from ._compat import integer_types +from ._compat import iteritems +from ._compat import iterkeys +from ._compat import iterlists +from ._compat import itervalues +from ._compat import make_literal_wrapper +from ._compat import PY2 +from ._compat import string_types +from ._compat import text_type +from ._compat import to_native +from ._internal import _missing +from .filesystem import get_filesystem_encoding -_locale_delim_re = re.compile(r'[_-]') +_locale_delim_re = re.compile(r"[_-]") def is_immutable(self): - raise TypeError('%r objects are immutable' % self.__class__.__name__) + raise TypeError("%r objects are immutable" % self.__class__.__name__) def iter_multi_items(mapping): @@ -52,18 +60,28 @@ def native_itermethods(names): return lambda x: x def setviewmethod(cls, name): - viewmethod_name = 'view%s' % name - viewmethod = lambda self, *a, **kw: ViewItems(self, name, 'view_%s' % name, *a, **kw) - viewmethod.__doc__ = \ - '"""`%s()` object providing a view on %s"""' % (viewmethod_name, name) + viewmethod_name = "view%s" % name + repr_name = "view_%s" % name + + def viewmethod(self, *a, **kw): + return ViewItems(self, name, repr_name, *a, **kw) + + viewmethod.__name__ = viewmethod_name + viewmethod.__doc__ = "`%s()` object providing a view on %s" % ( + viewmethod_name, + name, + ) setattr(cls, viewmethod_name, viewmethod) def setitermethod(cls, name): itermethod = getattr(cls, name) - setattr(cls, 'iter%s' % name, itermethod) - listmethod = lambda self, *a, **kw: list(itermethod(self, *a, **kw)) - listmethod.__doc__ = \ - 'Like :py:meth:`iter%s`, but returns a list.' % name + setattr(cls, "iter%s" % name, itermethod) + + def listmethod(self, *a, **kw): + return list(itermethod(self, *a, **kw)) + + listmethod.__name__ = name + listmethod.__doc__ = "Like :py:meth:`iter%s`, but returns a list." % name setattr(cls, name, listmethod) def wrap(cls): @@ -71,11 +89,11 @@ def native_itermethods(names): setitermethod(cls, name) setviewmethod(cls, name) return cls + return wrap class ImmutableListMixin(object): - """Makes a :class:`list` immutable. .. versionadded:: 0.5 @@ -99,6 +117,7 @@ class ImmutableListMixin(object): def __iadd__(self, other): is_immutable(self) + __imul__ = __iadd__ def __setitem__(self, key, value): @@ -106,6 +125,7 @@ class ImmutableListMixin(object): def append(self, item): is_immutable(self) + remove = append def extend(self, iterable): @@ -125,7 +145,6 @@ class ImmutableListMixin(object): class ImmutableList(ImmutableListMixin, list): - """An immutable :class:`list`. .. versionadded:: 0.5 @@ -134,20 +153,17 @@ class ImmutableList(ImmutableListMixin, list): """ def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, - list.__repr__(self), - ) + return "%s(%s)" % (self.__class__.__name__, list.__repr__(self)) class ImmutableDictMixin(object): - """Makes a :class:`dict` immutable. .. versionadded:: 0.5 :private: """ + _hash_cache = None @classmethod @@ -191,7 +207,6 @@ class ImmutableDictMixin(object): class ImmutableMultiDictMixin(ImmutableDictMixin): - """Makes a :class:`MultiDict` immutable. .. versionadded:: 0.5 @@ -222,7 +237,6 @@ class ImmutableMultiDictMixin(ImmutableDictMixin): class UpdateDictMixin(object): - """Makes dicts call `self.on_update` on modifications. .. versionadded:: 0.5 @@ -232,12 +246,13 @@ class UpdateDictMixin(object): on_update = None - def calls_update(name): + def calls_update(name): # noqa: B902 def oncall(self, *args, **kw): rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) if self.on_update is not None: self.on_update(self) return rv + oncall.__name__ = name return oncall @@ -258,16 +273,15 @@ class UpdateDictMixin(object): self.on_update(self) return rv - __setitem__ = calls_update('__setitem__') - __delitem__ = calls_update('__delitem__') - clear = calls_update('clear') - popitem = calls_update('popitem') - update = calls_update('update') + __setitem__ = calls_update("__setitem__") + __delitem__ = calls_update("__delitem__") + clear = calls_update("clear") + popitem = calls_update("popitem") + update = calls_update("update") del calls_update class TypeConversionDict(dict): - """Works like a regular dict but the :meth:`get` method can perform type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` are subclasses of this class and provide the same feature. @@ -309,7 +323,6 @@ class TypeConversionDict(dict): class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): - """Works like a :class:`TypeConversionDict` but does not support modifications. @@ -328,7 +341,6 @@ class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): class ViewItems(object): - def __init__(self, multi_dict, method, repr_name, *a, **kw): self.__multi_dict = multi_dict self.__method = method @@ -340,15 +352,14 @@ class ViewItems(object): return getattr(self.__multi_dict, self.__method)(*self.__a, **self.__kw) def __repr__(self): - return '%s(%r)' % (self.__repr_name, list(self.__get_items())) + return "%s(%r)" % (self.__repr_name, list(self.__get_items())) def __iter__(self): return iter(self.__get_items()) -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class MultiDict(TypeConversionDict): - """A :class:`MultiDict` is a dictionary subclass customized to deal with multiple values for the same key which is for example used by the parsing functions in the wrappers. This is necessary because some HTML form @@ -678,17 +689,17 @@ class MultiDict(TypeConversionDict): return self.deepcopy(memo=memo) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, list(iteritems(self, multi=True))) + return "%s(%r)" % (self.__class__.__name__, list(iteritems(self, multi=True))) class _omd_bucket(object): - """Wraps values in the :class:`OrderedMultiDict`. This makes it possible to keep an order over multiple different keys. It requires a lot of extra memory and slows down access a lot, but makes it possible to access elements in O(1) and iterate in O(n). """ - __slots__ = ('prev', 'key', 'value', 'next') + + __slots__ = ("prev", "key", "value", "next") def __init__(self, omd, key, value): self.prev = omd._last_bucket @@ -713,9 +724,8 @@ class _omd_bucket(object): omd._last_bucket = self.prev -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class OrderedMultiDict(MultiDict): - """Works like a regular :class:`MultiDict` but preserves the order of the fields. To convert the ordered multi dict into a list you can use the :meth:`items` method and pass it ``multi=True``. @@ -822,7 +832,7 @@ class OrderedMultiDict(MultiDict): ptr = ptr.next def listvalues(self): - for key, values in iterlists(self): + for _key, values in iterlists(self): yield values def add(self, key, value): @@ -849,8 +859,7 @@ class OrderedMultiDict(MultiDict): self.add(key, value) def setlistdefault(self, key, default_list=None): - raise TypeError('setlistdefault is unsupported for ' - 'ordered multi dicts') + raise TypeError("setlistdefault is unsupported for ordered multi dicts") def update(self, mapping): for key, value in iter_multi_items(mapping): @@ -893,21 +902,21 @@ class OrderedMultiDict(MultiDict): def _options_header_vkw(value, kw): - return dump_options_header(value, dict((k.replace('_', '-'), v) - for k, v in kw.items())) + return dump_options_header( + value, dict((k.replace("_", "-"), v) for k, v in kw.items()) + ) def _unicodify_header_value(value): if isinstance(value, bytes): - value = value.decode('latin-1') + value = value.decode("latin-1") if not isinstance(value, text_type): value = text_type(value) return value -@native_itermethods(['keys', 'values', 'items']) +@native_itermethods(["keys", "values", "items"]) class Headers(object): - """An object that stores some headers. It has a dict-like interface but is ordered and can store the same keys multiple times. @@ -968,8 +977,7 @@ class Headers(object): raise exceptions.BadRequestKeyError(key) def __eq__(self, other): - return other.__class__ is self.__class__ and \ - set(other._list) == set(self._list) + return other.__class__ is self.__class__ and set(other._list) == set(self._list) __hash__ = None @@ -1007,7 +1015,7 @@ class Headers(object): except KeyError: return default if as_bytes: - rv = rv.encode('latin1') + rv = rv.encode("latin1") if type is None: return rv try: @@ -1036,7 +1044,7 @@ class Headers(object): for k, v in self: if k.lower() == ikey: if as_bytes: - v = v.encode('latin1') + v = v.encode("latin1") if type is not None: try: v = type(v) @@ -1168,10 +1176,12 @@ class Headers(object): def _validate_value(self, value): if not isinstance(value, text_type): - raise TypeError('Value should be unicode.') - if u'\n' in value or u'\r' in value: - raise ValueError('Detected newline in header value. This is ' - 'a potential security problem') + raise TypeError("Value should be unicode.") + if u"\n" in value or u"\r" in value: + raise ValueError( + "Detected newline in header value. This is " + "a potential security problem" + ) def add_header(self, _key, _value, **_kw): """Add a new header tuple to the list. @@ -1210,7 +1220,7 @@ class Headers(object): return listiter = iter(self._list) ikey = _key.lower() - for idx, (old_key, old_value) in enumerate(listiter): + for idx, (old_key, _old_value) in enumerate(listiter): if old_key.lower() == ikey: # replace first ocurrence self._list[idx] = (_key, _value) @@ -1218,7 +1228,7 @@ class Headers(object): else: self._list.append((_key, _value)) return - self._list[idx + 1:] = [t for t in listiter if t[0].lower() != ikey] + self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] def setdefault(self, key, default): """Returns the value for the key if it is in the dict, otherwise it @@ -1250,17 +1260,18 @@ class Headers(object): else: self.set(key, value) - def to_list(self, charset='iso-8859-1'): + def to_list(self, charset="iso-8859-1"): """Convert the headers into a list suitable for WSGI. .. deprecated:: 0.9 """ from warnings import warn + warn( "'to_list' deprecated as of version 0.9 and will be removed" " in version 1.0. Use 'to_wsgi_list' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return self.to_wsgi_list() @@ -1273,7 +1284,7 @@ class Headers(object): :return: list """ if PY2: - return [(to_native(k), v.encode('latin1')) for k, v in self] + return [(to_native(k), v.encode("latin1")) for k, v in self] return list(self) def copy(self): @@ -1286,19 +1297,15 @@ class Headers(object): """Returns formatted headers suitable for HTTP transmission.""" strs = [] for key, value in self.to_wsgi_list(): - strs.append('%s: %s' % (key, value)) - strs.append('\r\n') - return '\r\n'.join(strs) + strs.append("%s: %s" % (key, value)) + strs.append("\r\n") + return "\r\n".join(strs) def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - list(self) - ) + return "%s(%r)" % (self.__class__.__name__, list(self)) class ImmutableHeadersMixin(object): - """Makes a :class:`Headers` immutable. We do not mark them as hashable though since the only usecase for this datastructure in Werkzeug is a view on a mutable structure. @@ -1313,10 +1320,12 @@ class ImmutableHeadersMixin(object): def __setitem__(self, key, value): is_immutable(self) + set = __setitem__ def add(self, item): is_immutable(self) + remove = add_header = add def extend(self, iterable): @@ -1336,7 +1345,6 @@ class ImmutableHeadersMixin(object): class EnvironHeaders(ImmutableHeadersMixin, Headers): - """Read only version of the headers from a WSGI environment. This provides the same interface as `Headers` and is constructed from a WSGI environment. @@ -1360,10 +1368,10 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): # used because get() calls it. if not isinstance(key, string_types): raise KeyError(key) - key = key.upper().replace('-', '_') - if key in ('CONTENT_TYPE', 'CONTENT_LENGTH'): + key = key.upper().replace("-", "_") + if key in ("CONTENT_TYPE", "CONTENT_LENGTH"): return _unicodify_header_value(self.environ[key]) - return _unicodify_header_value(self.environ['HTTP_' + key]) + return _unicodify_header_value(self.environ["HTTP_" + key]) def __len__(self): # the iter is necessary because otherwise list calls our @@ -1372,21 +1380,23 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): def __iter__(self): for key, value in iteritems(self.environ): - if key.startswith('HTTP_') and key not in \ - ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): - yield (key[5:].replace('_', '-').title(), - _unicodify_header_value(value)) - elif key in ('CONTENT_TYPE', 'CONTENT_LENGTH') and value: - yield (key.replace('_', '-').title(), - _unicodify_header_value(value)) + if key.startswith("HTTP_") and key not in ( + "HTTP_CONTENT_TYPE", + "HTTP_CONTENT_LENGTH", + ): + yield ( + key[5:].replace("_", "-").title(), + _unicodify_header_value(value), + ) + elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: + yield (key.replace("_", "-").title(), _unicodify_header_value(value)) def copy(self): - raise TypeError('cannot create %r copies' % self.__class__.__name__) + raise TypeError("cannot create %r copies" % self.__class__.__name__) -@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues']) +@native_itermethods(["keys", "values", "items", "lists", "listvalues"]) class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): - """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` instances as sequence and it will combine the return values of all wrapped dicts: @@ -1417,8 +1427,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): @classmethod def fromkeys(cls): - raise TypeError('cannot create %r instances by fromkeys' % - cls.__name__) + raise TypeError("cannot create %r instances by fromkeys" % cls.__name__) def __getitem__(self, key): for d in self.dicts: @@ -1471,7 +1480,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): yield key, value def values(self): - for key, value in iteritems(self): + for _key, value in iteritems(self): yield value def lists(self): @@ -1523,11 +1532,10 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): has_key = __contains__ def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.dicts) + return "%s(%r)" % (self.__class__.__name__, self.dicts) class FileMultiDict(MultiDict): - """A special :class:`MultiDict` that has convenience methods to add files to it. This is used for :class:`EnvironBuilder` and generally useful for unittesting. @@ -1550,27 +1558,24 @@ class FileMultiDict(MultiDict): if isinstance(file, string_types): if filename is None: filename = file - file = open(file, 'rb') + file = open(file, "rb") if filename and content_type is None: - content_type = mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' + content_type = ( + mimetypes.guess_type(filename)[0] or "application/octet-stream" + ) value = FileStorage(file, filename, name, content_type) self.add(name, value) class ImmutableDict(ImmutableDictMixin, dict): - """An immutable :class:`dict`. .. versionadded:: 0.5 """ def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, - dict.__repr__(self), - ) + return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self)) def copy(self): """Return a shallow mutable copy of this object. Keep in mind that @@ -1584,7 +1589,6 @@ class ImmutableDict(ImmutableDictMixin, dict): class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): - """An immutable :class:`MultiDict`. .. versionadded:: 0.5 @@ -1602,7 +1606,6 @@ class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): - """An immutable :class:`OrderedMultiDict`. .. versionadded:: 0.6 @@ -1622,9 +1625,8 @@ class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): return self -@native_itermethods(['values']) +@native_itermethods(["values"]) class Accept(ImmutableList): - """An :class:`Accept` object is just a list subclass for lists of ``(value, quality)`` tuples. It is automatically sorted by specificity and quality. @@ -1663,17 +1665,20 @@ class Accept(ImmutableList): list.__init__(self, values) else: self.provided = True - values = sorted(values, key=lambda x: (self._specificity(x[0]), x[1], x[0]), - reverse=True) + values = sorted( + values, + key=lambda x: (self._specificity(x[0]), x[1], x[0]), + reverse=True, + ) list.__init__(self, values) def _specificity(self, value): """Returns a tuple describing the value's specificity.""" - return value != '*', + return (value != "*",) def _value_matches(self, value, item): """Check if a value matches a given accept item.""" - return item == '*' or item.lower() == value.lower() + return item == "*" or item.lower() == value.lower() def __getitem__(self, key): """Besides index lookup (getting item n) you can also pass it a string @@ -1697,15 +1702,15 @@ class Accept(ImmutableList): return 0 def __contains__(self, value): - for item, quality in self: + for item, _quality in self: if self._value_matches(value, item): return True return False def __repr__(self): - return '%s([%s])' % ( + return "%s([%s])" % ( self.__class__.__name__, - ', '.join('(%r, %s)' % (x, y) for x, y in self) + ", ".join("(%r, %s)" % (x, y) for x, y in self), ) def index(self, key): @@ -1718,7 +1723,7 @@ class Accept(ImmutableList): with the list API. """ if isinstance(key, string_types): - for idx, (item, quality) in enumerate(self): + for idx, (item, _quality) in enumerate(self): if self._value_matches(key, item): return idx raise ValueError(key) @@ -1744,9 +1749,9 @@ class Accept(ImmutableList): result = [] for value, quality in self: if quality != 1: - value = '%s;q=%s' % (value, quality) + value = "%s;q=%s" % (value, quality) result.append(value) - return ','.join(result) + return ",".join(result) def __str__(self): return self.to_header() @@ -1791,75 +1796,71 @@ class Accept(ImmutableList): class MIMEAccept(Accept): - """Like :class:`Accept` but with special methods and behavior for mimetypes. """ def _specificity(self, value): - return tuple(x != '*' for x in value.split('/', 1)) + return tuple(x != "*" for x in value.split("/", 1)) def _value_matches(self, value, item): def _normalize(x): x = x.lower() - return ('*', '*') if x == '*' else x.split('/', 1) + return ("*", "*") if x == "*" else x.split("/", 1) # this is from the application which is trusted. to avoid developer # frustration we actually check these for valid values - if '/' not in value: - raise ValueError('invalid mimetype %r' % value) + if "/" not in value: + raise ValueError("invalid mimetype %r" % value) value_type, value_subtype = _normalize(value) - if value_type == '*' and value_subtype != '*': - raise ValueError('invalid mimetype %r' % value) + if value_type == "*" and value_subtype != "*": + raise ValueError("invalid mimetype %r" % value) - if '/' not in item: + if "/" not in item: return False item_type, item_subtype = _normalize(item) - if item_type == '*' and item_subtype != '*': + if item_type == "*" and item_subtype != "*": return False return ( - (item_type == item_subtype == '*' - or value_type == value_subtype == '*') - or (item_type == value_type and (item_subtype == '*' - or value_subtype == '*' - or item_subtype == value_subtype)) + item_type == item_subtype == "*" or value_type == value_subtype == "*" + ) or ( + item_type == value_type + and ( + item_subtype == "*" + or value_subtype == "*" + or item_subtype == value_subtype + ) ) @property def accept_html(self): """True if this object accepts HTML.""" return ( - 'text/html' in self - or 'application/xhtml+xml' in self - or self.accept_xhtml + "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml ) @property def accept_xhtml(self): """True if this object accepts XHTML.""" - return ( - 'application/xhtml+xml' in self - or 'application/xml' in self - ) + return "application/xhtml+xml" in self or "application/xml" in self @property def accept_json(self): """True if this object accepts JSON.""" - return 'application/json' in self + return "application/json" in self class LanguageAccept(Accept): - """Like :class:`Accept` but with normalization for languages.""" def _value_matches(self, value, item): def _normalize(language): return _locale_delim_re.split(language.lower()) - return item == '*' or _normalize(value) == _normalize(item) + return item == "*" or _normalize(value) == _normalize(item) -class CharsetAccept(Accept): +class CharsetAccept(Accept): """Like :class:`Accept` but with normalization for charsets.""" def _value_matches(self, value, item): @@ -1868,20 +1869,22 @@ class CharsetAccept(Accept): return codecs.lookup(name).name except LookupError: return name.lower() - return item == '*' or _normalize(value) == _normalize(item) + + return item == "*" or _normalize(value) == _normalize(item) def cache_property(key, empty, type): """Return a new property object for a cache header. Useful if you want to add support for a cache extension in a subclass.""" - return property(lambda x: x._get_cache_value(key, empty, type), - lambda x, v: x._set_cache_value(key, v, type), - lambda x: x._del_cache_value(key), - 'accessor for %r' % key) + return property( + lambda x: x._get_cache_value(key, empty, type), + lambda x, v: x._set_cache_value(key, v, type), + lambda x: x._del_cache_value(key), + "accessor for %r" % key, + ) class _CacheControl(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Cache-Control header. It has accessors for all the cache-control directives specified in RFC 2616. The class does not differentiate between request and response directives. @@ -1913,10 +1916,10 @@ class _CacheControl(UpdateDictMixin, dict): no longer existing `CacheControl` class. """ - no_cache = cache_property('no-cache', '*', None) - no_store = cache_property('no-store', None, bool) - max_age = cache_property('max-age', -1, int) - no_transform = cache_property('no-transform', None, None) + no_cache = cache_property("no-cache", "*", None) + no_store = cache_property("no-store", None, bool) + max_age = cache_property("max-age", -1, int) + no_transform = cache_property("no-transform", None, None) def __init__(self, values=(), on_update=None): dict.__init__(self, values or ()) @@ -1966,16 +1969,13 @@ class _CacheControl(UpdateDictMixin, dict): return self.to_header() def __repr__(self): - return '<%s %s>' % ( + return "<%s %s>" % ( self.__class__.__name__, - " ".join( - "%s=%r" % (k, v) for k, v in sorted(self.items()) - ), + " ".join("%s=%r" % (k, v) for k, v in sorted(self.items())), ) class RequestCacheControl(ImmutableDictMixin, _CacheControl): - """A cache control for requests. This is immutable and gives access to all the request-relevant cache control headers. @@ -1989,14 +1989,13 @@ class RequestCacheControl(ImmutableDictMixin, _CacheControl): both for request and response. """ - max_stale = cache_property('max-stale', '*', int) - min_fresh = cache_property('min-fresh', '*', int) - no_transform = cache_property('no-transform', None, None) - only_if_cached = cache_property('only-if-cached', None, bool) + max_stale = cache_property("max-stale", "*", int) + min_fresh = cache_property("min-fresh", "*", int) + no_transform = cache_property("no-transform", None, None) + only_if_cached = cache_property("only-if-cached", None, bool) class ResponseCacheControl(_CacheControl): - """A cache control for responses. Unlike :class:`RequestCacheControl` this is mutable and gives access to response-relevant cache control headers. @@ -2011,11 +2010,11 @@ class ResponseCacheControl(_CacheControl): both for request and response. """ - public = cache_property('public', None, bool) - private = cache_property('private', '*', None) - must_revalidate = cache_property('must-revalidate', None, bool) - proxy_revalidate = cache_property('proxy-revalidate', None, bool) - s_maxage = cache_property('s-maxage', None, None) + public = cache_property("public", None, bool) + private = cache_property("private", "*", None) + must_revalidate = cache_property("must-revalidate", None, bool) + proxy_revalidate = cache_property("proxy-revalidate", None, bool) + s_maxage = cache_property("s-maxage", None, None) # attach cache_property to the _CacheControl as staticmethod @@ -2024,7 +2023,6 @@ _CacheControl.cache_property = staticmethod(cache_property) class CallbackDict(UpdateDictMixin, dict): - """A dict that calls a function passed every time something is changed. The function is passed the dict instance. """ @@ -2034,14 +2032,10 @@ class CallbackDict(UpdateDictMixin, dict): self.on_update = on_update def __repr__(self): - return '<%s %s>' % ( - self.__class__.__name__, - dict.__repr__(self) - ) + return "<%s %s>" % (self.__class__.__name__, dict.__repr__(self)) class HeaderSet(collections_abc.MutableSet): - """Similar to the :class:`ETags` class this implements a set-like structure. Unlike :class:`ETags` this is case insensitive and used for vary, allow, and content-language headers. @@ -2153,7 +2147,7 @@ class HeaderSet(collections_abc.MutableSet): def to_header(self): """Convert the header set into an HTTP header string.""" - return ', '.join(map(quote_header_value, self._headers)) + return ", ".join(map(quote_header_value, self._headers)) def __getitem__(self, idx): return self._headers[idx] @@ -2188,14 +2182,10 @@ class HeaderSet(collections_abc.MutableSet): return self.to_header() def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - self._headers - ) + return "%s(%r)" % (self.__class__.__name__, self._headers) class ETags(collections_abc.Container, collections_abc.Iterable): - """A set that can be used to check if one etag is present in a collection of etags. """ @@ -2245,15 +2235,14 @@ class ETags(collections_abc.Container, collections_abc.Iterable): def to_header(self): """Convert the etags set into a HTTP header string.""" if self.star_tag: - return '*' - return ', '.join( - ['"%s"' % x for x in self._strong] - + ['W/"%s"' % x for x in self._weak] + return "*" + return ", ".join( + ['"%s"' % x for x in self._strong] + ['W/"%s"' % x for x in self._weak] ) def __call__(self, etag=None, data=None, include_weak=False): if [etag, data].count(None) != 1: - raise TypeError('either tag or data required, but at least one') + raise TypeError("either tag or data required, but at least one") if etag is None: etag = generate_etag(data) if include_weak: @@ -2276,11 +2265,10 @@ class ETags(collections_abc.Container, collections_abc.Iterable): return self.contains(etag) def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class IfRange(object): - """Very simple object that represents the `If-Range` header in parsed form. It will either have neither a etag or date or one of either but never both. @@ -2301,13 +2289,13 @@ class IfRange(object): return http_date(self.date) if self.etag is not None: return quote_etag(self.etag) - return '' + return "" def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class Range(object): @@ -2339,7 +2327,7 @@ class Range(object): exactly one range and it is satisfiable it returns a ``(start, stop)`` tuple, otherwise `None`. """ - if self.units != 'bytes' or length is None or len(self.ranges) != 1: + if self.units != "bytes" or length is None or len(self.ranges) != 1: return None start, end = self.ranges[0] if end is None: @@ -2362,10 +2350,10 @@ class Range(object): ranges = [] for begin, end in self.ranges: if end is None: - ranges.append('%s-' % begin if begin >= 0 else str(begin)) + ranges.append("%s-" % begin if begin >= 0 else str(begin)) else: - ranges.append('%s-%s' % (begin, end - 1)) - return '%s=%s' % (self.units, ','.join(ranges)) + ranges.append("%s-%s" % (begin, end - 1)) + return "%s=%s" % (self.units, ",".join(ranges)) def to_content_range_header(self, length): """Converts the object into `Content-Range` HTTP header, @@ -2373,32 +2361,33 @@ class Range(object): """ range_for_length = self.range_for_length(length) if range_for_length is not None: - return '%s %d-%d/%d' % (self.units, - range_for_length[0], - range_for_length[1] - 1, length) + return "%s %d-%d/%d" % ( + self.units, + range_for_length[0], + range_for_length[1] - 1, + length, + ) return None def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class ContentRange(object): - """Represents the content range header. .. versionadded:: 0.7 """ def __init__(self, units, start, stop, length=None, on_update=None): - assert is_byte_range_valid(start, stop, length), \ - 'Bad range provided' + assert is_byte_range_valid(start, stop, length), "Bad range provided" self.on_update = on_update self.set(start, stop, length, units) - def _callback_property(name): + def _callback_property(name): # noqa: B902 def fget(self): return getattr(self, name) @@ -2406,22 +2395,23 @@ class ContentRange(object): setattr(self, name, value) if self.on_update is not None: self.on_update(self) + return property(fget, fset) #: The units to use, usually "bytes" - units = _callback_property('_units') + units = _callback_property("_units") #: The start point of the range or `None`. - start = _callback_property('_start') + start = _callback_property("_start") #: The stop point of the range (non-inclusive) or `None`. Can only be #: `None` if also start is `None`. - stop = _callback_property('_stop') + stop = _callback_property("_stop") #: The length of the range or `None`. - length = _callback_property('_length') + length = _callback_property("_length") + del _callback_property - def set(self, start, stop, length=None, units='bytes'): + def set(self, start, stop, length=None, units="bytes"): """Simple method to update the ranges.""" - assert is_byte_range_valid(start, stop, length), \ - 'Bad range provided' + assert is_byte_range_valid(start, stop, length), "Bad range provided" self._units = units self._start = start self._stop = stop @@ -2437,19 +2427,14 @@ class ContentRange(object): def to_header(self): if self.units is None: - return '' + return "" if self.length is None: - length = '*' + length = "*" else: length = self.length if self.start is None: - return '%s */%s' % (self.units, length) - return '%s %s-%s/%s' % ( - self.units, - self.start, - self.stop - 1, - length - ) + return "%s */%s" % (self.units, length) + return "%s %s-%s/%s" % (self.units, self.start, self.stop - 1, length) def __nonzero__(self): return self.units is not None @@ -2460,11 +2445,10 @@ class ContentRange(object): return self.to_header() def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) class Authorization(ImmutableDictMixin, dict): - """Represents an `Authorization` header sent by the client. You should not create this kind of object yourself but use it when it's returned by the `parse_authorization_header` function. @@ -2481,77 +2465,107 @@ class Authorization(ImmutableDictMixin, dict): dict.__init__(self, data or {}) self.type = auth_type - username = property(lambda x: x.get('username'), doc=''' + username = property( + lambda self: self.get("username"), + doc=""" The username transmitted. This is set for both basic and digest - auth all the time.''') - password = property(lambda x: x.get('password'), doc=''' + auth all the time.""", + ) + password = property( + lambda self: self.get("password"), + doc=""" When the authentication type is basic this is the password - transmitted by the client, else `None`.''') - realm = property(lambda x: x.get('realm'), doc=''' - This is the server realm sent back for HTTP digest auth.''') - nonce = property(lambda x: x.get('nonce'), doc=''' + transmitted by the client, else `None`.""", + ) + realm = property( + lambda self: self.get("realm"), + doc=""" + This is the server realm sent back for HTTP digest auth.""", + ) + nonce = property( + lambda self: self.get("nonce"), + doc=""" The nonce the server sent for digest auth, sent back by the client. A nonce should be unique for every 401 response for HTTP digest - auth.''') - uri = property(lambda x: x.get('uri'), doc=''' + auth.""", + ) + uri = property( + lambda self: self.get("uri"), + doc=""" The URI from Request-URI of the Request-Line; duplicated because proxies are allowed to change the Request-Line in transit. HTTP - digest auth only.''') - nc = property(lambda x: x.get('nc'), doc=''' + digest auth only.""", + ) + nc = property( + lambda self: self.get("nc"), + doc=""" The nonce count value transmitted by clients if a qop-header is - also transmitted. HTTP digest auth only.''') - cnonce = property(lambda x: x.get('cnonce'), doc=''' + also transmitted. HTTP digest auth only.""", + ) + cnonce = property( + lambda self: self.get("cnonce"), + doc=""" If the server sent a qop-header in the ``WWW-Authenticate`` header, the client has to provide this value for HTTP digest auth. - See the RFC for more details.''') - response = property(lambda x: x.get('response'), doc=''' + See the RFC for more details.""", + ) + response = property( + lambda self: self.get("response"), + doc=""" A string of 32 hex digits computed as defined in RFC 2617, which - proves that the user knows a password. Digest auth only.''') - opaque = property(lambda x: x.get('opaque'), doc=''' + proves that the user knows a password. Digest auth only.""", + ) + opaque = property( + lambda self: self.get("opaque"), + doc=""" The opaque header from the server returned unchanged by the client. It is recommended that this string be base64 or hexadecimal data. - Digest auth only.''') - qop = property(lambda x: x.get('qop'), doc=''' + Digest auth only.""", + ) + qop = property( + lambda self: self.get("qop"), + doc=""" Indicates what "quality of protection" the client has applied to the message for HTTP digest auth. Note that this is a single token, - not a quoted list of alternatives as in WWW-Authenticate.''') + not a quoted list of alternatives as in WWW-Authenticate.""", + ) class WWWAuthenticate(UpdateDictMixin, dict): - """Provides simple access to `WWW-Authenticate` headers.""" #: list of keys that require quoting in the generated header - _require_quoting = frozenset(['domain', 'nonce', 'opaque', 'realm', 'qop']) + _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"]) def __init__(self, auth_type=None, values=None, on_update=None): dict.__init__(self, values or ()) if auth_type: - self['__auth_type__'] = auth_type + self["__auth_type__"] = auth_type self.on_update = on_update - def set_basic(self, realm='authentication required'): + def set_basic(self, realm="authentication required"): """Clear the auth info and enable basic auth.""" dict.clear(self) - dict.update(self, {'__auth_type__': 'basic', 'realm': realm}) + dict.update(self, {"__auth_type__": "basic", "realm": realm}) if self.on_update: self.on_update(self) - def set_digest(self, realm, nonce, qop=('auth',), opaque=None, - algorithm=None, stale=False): + def set_digest( + self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False + ): """Clear the auth info and enable digest auth.""" d = { - '__auth_type__': 'digest', - 'realm': realm, - 'nonce': nonce, - 'qop': dump_header(qop) + "__auth_type__": "digest", + "realm": realm, + "nonce": nonce, + "qop": dump_header(qop), } if stale: - d['stale'] = 'TRUE' + d["stale"] = "TRUE" if opaque is not None: - d['opaque'] = opaque + d["opaque"] = opaque if algorithm is not None: - d['algorithm'] = algorithm + d["algorithm"] = algorithm dict.clear(self) dict.update(self, d) if self.on_update: @@ -2560,23 +2574,30 @@ class WWWAuthenticate(UpdateDictMixin, dict): def to_header(self): """Convert the stored values into a WWW-Authenticate header.""" d = dict(self) - auth_type = d.pop('__auth_type__', None) or 'basic' - return '%s %s' % (auth_type.title(), ', '.join([ - '%s=%s' % (key, quote_header_value(value, - allow_token=key not in self._require_quoting)) - for key, value in iteritems(d) - ])) + auth_type = d.pop("__auth_type__", None) or "basic" + return "%s %s" % ( + auth_type.title(), + ", ".join( + [ + "%s=%s" + % ( + key, + quote_header_value( + value, allow_token=key not in self._require_quoting + ), + ) + for key, value in iteritems(d) + ] + ), + ) def __str__(self): return self.to_header() def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.to_header() - ) + return "<%s %r>" % (self.__class__.__name__, self.to_header()) - def auth_property(name, doc=None): + def auth_property(name, doc=None): # noqa: B902 """A static helper function for subclasses to add extra authentication system properties onto a class:: @@ -2592,70 +2613,90 @@ class WWWAuthenticate(UpdateDictMixin, dict): self.pop(name, None) else: self[name] = str(value) + return property(lambda x: x.get(name), _set_value, doc=doc) - def _set_property(name, doc=None): + def _set_property(name, doc=None): # noqa: B902 def fget(self): def on_update(header_set): if not header_set and name in self: del self[name] elif header_set: self[name] = header_set.to_header() + return parse_set_header(self.get(name), on_update) + return property(fget, doc=doc) - type = auth_property('__auth_type__', doc=''' - The type of the auth mechanism. HTTP currently specifies - `Basic` and `Digest`.''') - realm = auth_property('realm', doc=''' - A string to be displayed to users so they know which username and - password to use. This string should contain at least the name of - the host performing the authentication and might additionally - indicate the collection of users who might have access.''') - domain = _set_property('domain', doc=''' - A list of URIs that define the protection space. If a URI is an - absolute path, it is relative to the canonical root URL of the - server being accessed.''') - nonce = auth_property('nonce', doc=''' + type = auth_property( + "__auth_type__", + doc="""The type of the auth mechanism. HTTP currently specifies + ``Basic`` and ``Digest``.""", + ) + realm = auth_property( + "realm", + doc="""A string to be displayed to users so they know which + username and password to use. This string should contain at + least the name of the host performing the authentication and + might additionally indicate the collection of users who might + have access.""", + ) + domain = _set_property( + "domain", + doc="""A list of URIs that define the protection space. If a URI + is an absolute path, it is relative to the canonical root URL of + the server being accessed.""", + ) + nonce = auth_property( + "nonce", + doc=""" A server-specified data string which should be uniquely generated - each time a 401 response is made. It is recommended that this - string be base64 or hexadecimal data.''') - opaque = auth_property('opaque', doc=''' - A string of data, specified by the server, which should be returned - by the client unchanged in the Authorization header of subsequent - requests with URIs in the same protection space. It is recommended - that this string be base64 or hexadecimal data.''') - algorithm = auth_property('algorithm', doc=''' - A string indicating a pair of algorithms used to produce the digest - and a checksum. If this is not present it is assumed to be "MD5". - If the algorithm is not understood, the challenge should be ignored - (and a different one used, if there is more than one).''') - qop = _set_property('qop', doc=''' - A set of quality-of-privacy directives such as auth and auth-int.''') - - def _get_stale(self): - val = self.get('stale') + each time a 401 response is made. It is recommended that this + string be base64 or hexadecimal data.""", + ) + opaque = auth_property( + "opaque", + doc="""A string of data, specified by the server, which should + be returned by the client unchanged in the Authorization header + of subsequent requests with URIs in the same protection space. + It is recommended that this string be base64 or hexadecimal + data.""", + ) + algorithm = auth_property( + "algorithm", + doc="""A string indicating a pair of algorithms used to produce + the digest and a checksum. If this is not present it is assumed + to be "MD5". If the algorithm is not understood, the challenge + should be ignored (and a different one used, if there is more + than one).""", + ) + qop = _set_property( + "qop", + doc="""A set of quality-of-privacy directives such as auth and + auth-int.""", + ) + + @property + def stale(self): + """A flag, indicating that the previous request from the client + was rejected because the nonce value was stale. + """ + val = self.get("stale") if val is not None: - return val.lower() == 'true' + return val.lower() == "true" - def _set_stale(self, value): + @stale.setter + def stale(self, value): if value is None: - self.pop('stale', None) + self.pop("stale", None) else: - self['stale'] = 'TRUE' if value else 'FALSE' - stale = property(_get_stale, _set_stale, doc=''' - A flag, indicating that the previous request from the client was - rejected because the nonce value was stale.''') - del _get_stale, _set_stale - - # make auth_property a staticmethod so that subclasses of - # `WWWAuthenticate` can use it for new properties. + self["stale"] = "TRUE" if value else "FALSE" + auth_property = staticmethod(auth_property) del _set_property class FileStorage(object): - """The :class:`FileStorage` class is a thin wrapper over incoming files. It is used by the request object to represent uploaded files. All the attributes of the wrapper stream are proxied by the file storage so @@ -2663,9 +2704,15 @@ class FileStorage(object): ``storage.stream.read()``. """ - def __init__(self, stream=None, filename=None, name=None, - content_type=None, content_length=None, - headers=None): + def __init__( + self, + stream=None, + filename=None, + name=None, + content_type=None, + content_length=None, + headers=None, + ): self.name = name self.stream = stream or BytesIO() @@ -2674,41 +2721,39 @@ class FileStorage(object): # skip things like <fdopen>, <stderr> etc. Python marks these # special filenames with angular brackets. if filename is None: - filename = getattr(stream, 'name', None) + filename = getattr(stream, "name", None) s = make_literal_wrapper(filename) - if filename and filename[0] == s('<') and filename[-1] == s('>'): + if filename and filename[0] == s("<") and filename[-1] == s(">"): filename = None # On Python 3 we want to make sure the filename is always unicode. # This might not be if the name attribute is bytes due to the # file being opened from the bytes API. if not PY2 and isinstance(filename, bytes): - filename = filename.decode(get_filesystem_encoding(), - 'replace') + filename = filename.decode(get_filesystem_encoding(), "replace") self.filename = filename if headers is None: headers = Headers() self.headers = headers if content_type is not None: - headers['Content-Type'] = content_type + headers["Content-Type"] = content_type if content_length is not None: - headers['Content-Length'] = str(content_length) + headers["Content-Length"] = str(content_length) def _parse_content_type(self): - if not hasattr(self, '_parsed_content_type'): - self._parsed_content_type = \ - parse_options_header(self.content_type) + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = parse_options_header(self.content_type) @property def content_type(self): """The content-type sent in the header. Usually not available""" - return self.headers.get('content-type') + return self.headers.get("content-type") @property def content_length(self): """The content-length sent in the header. Usually not available""" - return int(self.headers.get('content-length') or 0) + return int(self.headers.get("content-length") or 0) @property def mimetype(self): @@ -2748,9 +2793,10 @@ class FileStorage(object): :func:`shutil.copyfileobj`. """ from shutil import copyfileobj + close_dst = False if isinstance(dst, string_types): - dst = open(dst, 'wb') + dst = open(dst, "wb") close_dst = True try: copyfileobj(self.stream, dst, buffer_size) @@ -2767,6 +2813,7 @@ class FileStorage(object): def __nonzero__(self): return bool(self.filename) + __bool__ = __nonzero__ def __getattr__(self, name): @@ -2784,15 +2831,22 @@ class FileStorage(object): return iter(self.stream) def __repr__(self): - return '<%s: %r (%r)>' % ( + return "<%s: %r (%r)>" % ( self.__class__.__name__, self.filename, - self.content_type + self.content_type, ) # circular dependencies -from werkzeug.http import dump_options_header, dump_header, generate_etag, \ - quote_header_value, parse_set_header, unquote_etag, quote_etag, \ - parse_options_header, http_date, is_byte_range_valid -from werkzeug import exceptions +from . import exceptions +from .http import dump_header +from .http import dump_options_header +from .http import generate_etag +from .http import http_date +from .http import is_byte_range_valid +from .http import parse_options_header +from .http import parse_set_header +from .http import quote_etag +from .http import quote_header_value +from .http import unquote_etag diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index cb984b03..beb7729e 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -8,33 +8,35 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import getpass +import hashlib +import json +import mimetypes import os import pkgutil import re import sys -import uuid -import json import time -import getpass -import hashlib -import mimetypes +import uuid from itertools import chain -from os.path import join, basename -from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response -from werkzeug.http import parse_cookie -from werkzeug.debug.tbtools import get_current_traceback, render_console_html -from werkzeug.debug.console import Console -from werkzeug.security import gen_salt -from werkzeug._internal import _log -from werkzeug._compat import text_type +from os.path import basename +from os.path import join - -# DEPRECATED -from werkzeug.debug.repr import debug_repr as _debug_repr +from .._compat import text_type +from .._internal import _log +from ..http import parse_cookie +from ..security import gen_salt +from ..wrappers import BaseRequest as Request +from ..wrappers import BaseResponse as Response +from .console import Console +from .repr import debug_repr as _debug_repr +from .tbtools import get_current_traceback +from .tbtools import render_console_html def debug_repr(*args, **kwargs): import warnings + warnings.warn( "'debug_repr' has moved to 'werkzeug.debug.repr.debug_repr'" " as of version 0.7. This old import will be removed in version" @@ -51,8 +53,8 @@ PIN_TIME = 60 * 60 * 24 * 7 def hash_pin(pin): if isinstance(pin, text_type): - pin = pin.encode('utf-8', 'replace') - return hashlib.md5(pin + b'shittysalt').hexdigest()[:12] + pin = pin.encode("utf-8", "replace") + return hashlib.md5(pin + b"shittysalt").hexdigest()[:12] _machine_id = None @@ -67,9 +69,9 @@ def get_machine_id(): def _generate(): # Potential sources of secret information on linux. The machine-id # is stable across boots, the boot id is not - for filename in '/etc/machine-id', '/proc/sys/kernel/random/boot_id': + for filename in "/etc/machine-id", "/proc/sys/kernel/random/boot_id": try: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return f.readline().strip() except IOError: continue @@ -81,8 +83,10 @@ def get_machine_id(): # Google App Engine # See https://github.com/pallets/werkzeug/issues/925 from subprocess import Popen, PIPE - dump = Popen(['ioreg', '-c', 'IOPlatformExpertDevice', '-d', '2'], - stdout=PIPE).communicate()[0] + + dump = Popen( + ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE + ).communicate()[0] match = re.search(b'"serial-number" = <([^>]+)', dump) if match is not None: return match.group(1) @@ -100,12 +104,15 @@ def get_machine_id(): pass if wr is not None: try: - with wr.OpenKey(wr.HKEY_LOCAL_MACHINE, - 'SOFTWARE\\Microsoft\\Cryptography', 0, - wr.KEY_READ | wr.KEY_WOW64_64KEY) as rk: - machineGuid, wrType = wr.QueryValueEx(rk, 'MachineGuid') + with wr.OpenKey( + wr.HKEY_LOCAL_MACHINE, + "SOFTWARE\\Microsoft\\Cryptography", + 0, + wr.KEY_READ | wr.KEY_WOW64_64KEY, + ) as rk: + machineGuid, wrType = wr.QueryValueEx(rk, "MachineGuid") if wrType == wr.REG_SZ: - return machineGuid.encode('utf-8') + return machineGuid.encode("utf-8") else: return machineGuid except WindowsError: @@ -116,7 +123,6 @@ def get_machine_id(): class _ConsoleFrame(object): - """Helper class so that we can reuse the frame console code for the standalone console. """ @@ -134,24 +140,23 @@ def get_pin_and_cookie_name(app): Second item in the resulting tuple is the cookie name for remembering. """ - pin = os.environ.get('WERKZEUG_DEBUG_PIN') + pin = os.environ.get("WERKZEUG_DEBUG_PIN") rv = None num = None # Pin was explicitly disabled - if pin == 'off': + if pin == "off": return None, None # Pin was provided explicitly - if pin is not None and pin.replace('-', '').isdigit(): + if pin is not None and pin.replace("-", "").isdigit(): # If there are separators in the pin, return it directly - if '-' in pin: + if "-" in pin: rv = pin else: num = pin - modname = getattr(app, '__module__', - getattr(app.__class__, '__module__')) + modname = getattr(app, "__module__", getattr(app.__class__, "__module__")) try: # `getpass.getuser()` imports the `pwd` module, @@ -167,42 +172,41 @@ def get_pin_and_cookie_name(app): probably_public_bits = [ username, modname, - getattr(app, '__name__', getattr(app.__class__, '__name__')), - getattr(mod, '__file__', None), + getattr(app, "__name__", getattr(app.__class__, "__name__")), + getattr(mod, "__file__", None), ] # This information is here to make it harder for an attacker to # guess the cookie name. They are unlikely to be contained anywhere # within the unauthenticated debug page. - private_bits = [ - str(uuid.getnode()), - get_machine_id(), - ] + private_bits = [str(uuid.getnode()), get_machine_id()] h = hashlib.md5() for bit in chain(probably_public_bits, private_bits): if not bit: continue if isinstance(bit, text_type): - bit = bit.encode('utf-8') + bit = bit.encode("utf-8") h.update(bit) - h.update(b'cookiesalt') + h.update(b"cookiesalt") - cookie_name = '__wzd' + h.hexdigest()[:20] + cookie_name = "__wzd" + h.hexdigest()[:20] # If we need to generate a pin we salt it a bit more so that we don't # end up with the same value and generate out 9 digits if num is None: - h.update(b'pinsalt') - num = ('%09d' % int(h.hexdigest(), 16))[:9] + h.update(b"pinsalt") + num = ("%09d" % int(h.hexdigest(), 16))[:9] # Format the pincode in groups of digits for easier remembering if # we don't have a result yet. if rv is None: for group_size in 5, 4, 3: if len(num) % group_size == 0: - rv = '-'.join(num[x:x + group_size].rjust(group_size, '0') - for x in range(0, len(num), group_size)) + rv = "-".join( + num[x : x + group_size].rjust(group_size, "0") + for x in range(0, len(num), group_size) + ) break else: rv = num @@ -240,12 +244,21 @@ class DebuggedApplication(object): :param pin_logging: enables the logging of the pin system. """ - def __init__(self, app, evalex=False, request_key='werkzeug.request', - console_path='/console', console_init_func=None, - show_hidden_frames=False, lodgeit_url=None, - pin_security=True, pin_logging=True): + def __init__( + self, + app, + evalex=False, + request_key="werkzeug.request", + console_path="/console", + console_init_func=None, + show_hidden_frames=False, + lodgeit_url=None, + pin_security=True, + pin_logging=True, + ): if lodgeit_url is not None: from warnings import warn + warn( "'lodgeit_url' is no longer used as of version 0.9 and" " will be removed in version 1.0. Werkzeug uses" @@ -269,19 +282,17 @@ class DebuggedApplication(object): self.pin_logging = pin_logging if pin_security: # Print out the pin for the debugger on standard out. - if os.environ.get('WERKZEUG_RUN_MAIN') == 'true' and \ - pin_logging: - _log('warning', ' * Debugger is active!') + if os.environ.get("WERKZEUG_RUN_MAIN") == "true" and pin_logging: + _log("warning", " * Debugger is active!") if self.pin is None: - _log('warning', ' * Debugger PIN disabled. ' - 'DEBUGGER UNSECURED!') + _log("warning", " * Debugger PIN disabled. DEBUGGER UNSECURED!") else: - _log('info', ' * Debugger PIN: %s' % self.pin) + _log("info", " * Debugger PIN: %s" % self.pin) else: self.pin = None def _get_pin(self): - if not hasattr(self, '_pin'): + if not hasattr(self, "_pin"): self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) return self._pin @@ -294,7 +305,7 @@ class DebuggedApplication(object): @property def pin_cookie_name(self): """The name of the pin cookie.""" - if not hasattr(self, '_pin_cookie'): + if not hasattr(self, "_pin_cookie"): self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) return self._pin_cookie @@ -305,46 +316,51 @@ class DebuggedApplication(object): app_iter = self.app(environ, start_response) for item in app_iter: yield item - if hasattr(app_iter, 'close'): + if hasattr(app_iter, "close"): app_iter.close() except Exception: - if hasattr(app_iter, 'close'): + if hasattr(app_iter, "close"): app_iter.close() traceback = get_current_traceback( - skip=1, show_hidden_frames=self.show_hidden_frames, - ignore_system_exceptions=True) + skip=1, + show_hidden_frames=self.show_hidden_frames, + ignore_system_exceptions=True, + ) for frame in traceback.frames: self.frames[frame.id] = frame self.tracebacks[traceback.id] = traceback try: - start_response('500 INTERNAL SERVER ERROR', [ - ('Content-Type', 'text/html; charset=utf-8'), - # Disable Chrome's XSS protection, the debug - # output can cause false-positives. - ('X-XSS-Protection', '0'), - ]) + start_response( + "500 INTERNAL SERVER ERROR", + [ + ("Content-Type", "text/html; charset=utf-8"), + # Disable Chrome's XSS protection, the debug + # output can cause false-positives. + ("X-XSS-Protection", "0"), + ], + ) except Exception: # if we end up here there has been output but an error # occurred. in that situation we can do nothing fancy any # more, better log something into the error log and fall # back gracefully. - environ['wsgi.errors'].write( - 'Debugging middleware caught exception in streamed ' - 'response at a point where response headers were already ' - 'sent.\n') + environ["wsgi.errors"].write( + "Debugging middleware caught exception in streamed " + "response at a point where response headers were already " + "sent.\n" + ) else: is_trusted = bool(self.check_pin_trust(environ)) - yield traceback.render_full(evalex=self.evalex, - evalex_trusted=is_trusted, - secret=self.secret) \ - .encode('utf-8', 'replace') + yield traceback.render_full( + evalex=self.evalex, evalex_trusted=is_trusted, secret=self.secret + ).encode("utf-8", "replace") - traceback.log(environ['wsgi.errors']) + traceback.log(environ["wsgi.errors"]) def execute_command(self, request, command, frame): """Execute a command in a console.""" - return Response(frame.console.eval(command), mimetype='text/html') + return Response(frame.console.eval(command), mimetype="text/html") def display_console(self, request): """Display a standalone shell.""" @@ -353,30 +369,30 @@ class DebuggedApplication(object): ns = {} else: ns = dict(self.console_init_func()) - ns.setdefault('app', self.app) + ns.setdefault("app", self.app) self.frames[0] = _ConsoleFrame(ns) is_trusted = bool(self.check_pin_trust(request.environ)) - return Response(render_console_html(secret=self.secret, - evalex_trusted=is_trusted), - mimetype='text/html') + return Response( + render_console_html(secret=self.secret, evalex_trusted=is_trusted), + mimetype="text/html", + ) def paste_traceback(self, request, traceback): """Paste the traceback and return a JSON response.""" rv = traceback.paste() - return Response(json.dumps(rv), mimetype='application/json') + return Response(json.dumps(rv), mimetype="application/json") def get_resource(self, request, filename): """Return a static resource from the shared folder.""" - filename = join('shared', basename(filename)) + filename = join("shared", basename(filename)) try: data = pkgutil.get_data(__package__, filename) except OSError: data = None if data is not None: - mimetype = mimetypes.guess_type(filename)[0] \ - or 'application/octet-stream' + mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream" return Response(data, mimetype=mimetype) - return Response('Not Found', status=404) + return Response("Not Found", status=404) def check_pin_trust(self, environ): """Checks if the request passed the pin test. This returns `True` if the @@ -387,9 +403,9 @@ class DebuggedApplication(object): if self.pin is None: return True val = parse_cookie(environ).get(self.pin_cookie_name) - if not val or '|' not in val: + if not val or "|" not in val: return False - ts, pin_hash = val.split('|', 1) + ts, pin_hash = val.split("|", 1) if not ts.isdigit(): return False if pin_hash != hash_pin(self.pin): @@ -426,23 +442,23 @@ class DebuggedApplication(object): # Otherwise go through pin based authentication else: - entered_pin = request.args.get('pin') - if entered_pin.strip().replace('-', '') == \ - self.pin.replace('-', ''): + entered_pin = request.args.get("pin") + if entered_pin.strip().replace("-", "") == self.pin.replace("-", ""): self._failed_pin_auth = 0 auth = True else: self._fail_pin_auth() - rv = Response(json.dumps({ - 'auth': auth, - 'exhausted': exhausted, - }), mimetype='application/json') + rv = Response( + json.dumps({"auth": auth, "exhausted": exhausted}), + mimetype="application/json", + ) if auth: - rv.set_cookie(self.pin_cookie_name, '%s|%s' % ( - int(time.time()), - hash_pin(self.pin) - ), httponly=True) + rv.set_cookie( + self.pin_cookie_name, + "%s|%s" % (int(time.time()), hash_pin(self.pin)), + httponly=True, + ) elif bad_cookie: rv.delete_cookie(self.pin_cookie_name) return rv @@ -450,10 +466,11 @@ class DebuggedApplication(object): def log_pin_request(self): """Log the pin if needed.""" if self.pin_logging and self.pin is not None: - _log('info', ' * To enable the debugger you need to ' - 'enter the security pin:') - _log('info', ' * Debugger pin code: %s' % self.pin) - return Response('') + _log( + "info", " * To enable the debugger you need to enter the security pin:" + ) + _log("info", " * Debugger pin code: %s" % self.pin) + return Response("") def __call__(self, environ, start_response): """Dispatch the requests.""" @@ -462,26 +479,32 @@ class DebuggedApplication(object): # any more! request = Request(environ) response = self.debug_application - if request.args.get('__debugger__') == 'yes': - cmd = request.args.get('cmd') - arg = request.args.get('f') - secret = request.args.get('s') - traceback = self.tracebacks.get(request.args.get('tb', type=int)) - frame = self.frames.get(request.args.get('frm', type=int)) - if cmd == 'resource' and arg: + if request.args.get("__debugger__") == "yes": + cmd = request.args.get("cmd") + arg = request.args.get("f") + secret = request.args.get("s") + traceback = self.tracebacks.get(request.args.get("tb", type=int)) + frame = self.frames.get(request.args.get("frm", type=int)) + if cmd == "resource" and arg: response = self.get_resource(request, arg) - elif cmd == 'paste' and traceback is not None and \ - secret == self.secret: + elif cmd == "paste" and traceback is not None and secret == self.secret: response = self.paste_traceback(request, traceback) - elif cmd == 'pinauth' and secret == self.secret: + elif cmd == "pinauth" and secret == self.secret: response = self.pin_auth(request) - elif cmd == 'printpin' and secret == self.secret: + elif cmd == "printpin" and secret == self.secret: response = self.log_pin_request() - elif self.evalex and cmd is not None and frame is not None \ - and self.secret == secret and \ - self.check_pin_trust(environ): + elif ( + self.evalex + and cmd is not None + and frame is not None + and self.secret == secret + and self.check_pin_trust(environ) + ): response = self.execute_command(request, cmd, frame) - elif self.evalex and self.console_path is not None and \ - request.path == self.console_path: + elif ( + self.evalex + and self.console_path is not None + and request.path == self.console_path + ): response = self.display_console(request) return response(environ, start_response) diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 2915aea1..adbd170b 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -8,20 +8,21 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys import code +import sys from types import CodeType -from werkzeug.utils import escape -from werkzeug.local import Local -from werkzeug.debug.repr import debug_repr, dump, helper +from ..local import Local +from ..utils import escape +from .repr import debug_repr +from .repr import dump +from .repr import helper _local = Local() class HTMLStringO(object): - """A StringO version that HTML escapes on write.""" def __init__(self): @@ -41,46 +42,46 @@ class HTMLStringO(object): def readline(self): if len(self._buffer) == 0: - return '' + return "" ret = self._buffer[0] del self._buffer[0] return ret def reset(self): - val = ''.join(self._buffer) + val = "".join(self._buffer) del self._buffer[:] return val def _write(self, x): if isinstance(x, bytes): - x = x.decode('utf-8', 'replace') + x = x.decode("utf-8", "replace") self._buffer.append(x) def write(self, x): self._write(escape(x)) def writelines(self, x): - self._write(escape(''.join(x))) + self._write(escape("".join(x))) class ThreadedStream(object): - """Thread-local wrapper for sys.stdout for the interactive console.""" + @staticmethod def push(): if not isinstance(sys.stdout, ThreadedStream): sys.stdout = ThreadedStream() _local.stream = HTMLStringO() - push = staticmethod(push) + @staticmethod def fetch(): try: stream = _local.stream except AttributeError: - return '' + return "" return stream.reset() - fetch = staticmethod(fetch) + @staticmethod def displayhook(obj): try: stream = _local.stream @@ -89,18 +90,17 @@ class ThreadedStream(object): # stream._write bypasses escaping as debug_repr is # already generating HTML for us. if obj is not None: - _local._current_ipy.locals['_'] = obj + _local._current_ipy.locals["_"] = obj stream._write(debug_repr(obj)) - displayhook = staticmethod(displayhook) def __setattr__(self, name, value): - raise AttributeError('read only attribute %s' % name) + raise AttributeError("read only attribute %s" % name) def __dir__(self): return dir(sys.__stdout__) def __getattribute__(self, name): - if name == '__members__': + if name == "__members__": return dir(sys.__stdout__) try: stream = _local.stream @@ -118,7 +118,6 @@ sys.displayhook = ThreadedStream.displayhook class _ConsoleLoader(object): - def __init__(self): self._storage = {} @@ -143,29 +142,30 @@ def _wrap_compiler(console): code = compile(source, filename, symbol) console.loader.register(code, source) return code + console.compile = func class _InteractiveConsole(code.InteractiveInterpreter): - def __init__(self, globals, locals): code.InteractiveInterpreter.__init__(self, locals) self.globals = dict(globals) - self.globals['dump'] = dump - self.globals['help'] = helper - self.globals['__loader__'] = self.loader = _ConsoleLoader() + self.globals["dump"] = dump + self.globals["help"] = helper + self.globals["__loader__"] = self.loader = _ConsoleLoader() self.more = False self.buffer = [] _wrap_compiler(self) def runsource(self, source): - source = source.rstrip() + '\n' + source = source.rstrip() + "\n" ThreadedStream.push() - prompt = '... ' if self.more else '>>> ' + prompt = "... " if self.more else ">>> " try: - source_to_eval = ''.join(self.buffer + [source]) - if code.InteractiveInterpreter.runsource(self, - source_to_eval, '<debugger>', 'single'): + source_to_eval = "".join(self.buffer + [source]) + if code.InteractiveInterpreter.runsource( + self, source_to_eval, "<debugger>", "single" + ): self.more = True self.buffer.append(source) else: @@ -182,12 +182,14 @@ class _InteractiveConsole(code.InteractiveInterpreter): self.showtraceback() def showtraceback(self): - from werkzeug.debug.tbtools import get_current_traceback + from .tbtools import get_current_traceback + tb = get_current_traceback(skip=1) sys.stdout._write(tb.render_summary()) def showsyntaxerror(self, filename=None): - from werkzeug.debug.tbtools import get_current_traceback + from .tbtools import get_current_traceback + tb = get_current_traceback(skip=4) sys.stdout._write(tb.render_summary()) @@ -196,7 +198,6 @@ class _InteractiveConsole(code.InteractiveInterpreter): class Console(object): - """An interactive console.""" def __init__(self, globals=None, locals=None): diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index e82d87c4..d7a7285c 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -13,37 +13,38 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys -import re import codecs +import re +import sys +from collections import deque from traceback import format_exception_only -try: - from collections import deque -except ImportError: # pragma: no cover - deque = None -from werkzeug.utils import escape -from werkzeug._compat import iteritems, PY2, text_type, integer_types, \ - string_types + +from .._compat import integer_types +from .._compat import iteritems +from .._compat import PY2 +from .._compat import string_types +from .._compat import text_type +from ..utils import escape missing = object() -_paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}') +_paragraph_re = re.compile(r"(?:\r\n|\r|\n){2,}") RegexType = type(_paragraph_re) -HELP_HTML = '''\ +HELP_HTML = """\ <div class=box> <h3>%(title)s</h3> <pre class=help>%(text)s</pre> </div>\ -''' -OBJECT_DUMP_HTML = '''\ +""" +OBJECT_DUMP_HTML = """\ <div class=box> <h3>%(title)s</h3> %(repr)s <table>%(items)s</table> </div>\ -''' +""" def debug_repr(obj): @@ -64,31 +65,31 @@ def dump(obj=missing): class _Helper(object): - """Displays an HTML version of the normal help, for the interactive debugger only because it requires a patched sys.stdout. """ def __repr__(self): - return 'Type help(object) for help about object.' + return "Type help(object) for help about object." def __call__(self, topic=None): if topic is None: - sys.stdout._write('<span class=help>%s</span>' % repr(self)) + sys.stdout._write("<span class=help>%s</span>" % repr(self)) return import pydoc + pydoc.help(topic) rv = sys.stdout.reset() if isinstance(rv, bytes): - rv = rv.decode('utf-8', 'ignore') + rv = rv.decode("utf-8", "ignore") paragraphs = _paragraph_re.split(rv) if len(paragraphs) > 1: title = paragraphs[0] - text = '\n\n'.join(paragraphs[1:]) + text = "\n\n".join(paragraphs[1:]) else: # pragma: no cover - title = 'Help' + title = "Help" text = paragraphs[0] - sys.stdout._write(HELP_HTML % {'title': title, 'text': text}) + sys.stdout._write(HELP_HTML % {"title": title, "text": text}) helper = _Helper() @@ -101,55 +102,55 @@ def _add_subclass_info(inner, obj, base): return inner elif type(obj) is base: return inner - module = '' - if obj.__class__.__module__ not in ('__builtin__', 'exceptions'): + module = "" + if obj.__class__.__module__ not in ("__builtin__", "exceptions"): module = '<span class="module">%s.</span>' % obj.__class__.__module__ - return '%s%s(%s)' % (module, obj.__class__.__name__, inner) + return "%s%s(%s)" % (module, obj.__class__.__name__, inner) class DebugReprGenerator(object): - def __init__(self): self._stack = [] - def _sequence_repr_maker(left, right, base=object(), limit=8): + def _sequence_repr_maker(left, right, base=object(), limit=8): # noqa: B008, B902 def proxy(self, obj, recursive): if recursive: - return _add_subclass_info(left + '...' + right, obj, base) + return _add_subclass_info(left + "..." + right, obj, base) buf = [left] have_extended_section = False for idx, item in enumerate(obj): if idx: - buf.append(', ') + buf.append(", ") if idx == limit: buf.append('<span class="extended">') have_extended_section = True buf.append(self.repr(item)) if have_extended_section: - buf.append('</span>') + buf.append("</span>") buf.append(right) - return _add_subclass_info(u''.join(buf), obj, base) + return _add_subclass_info(u"".join(buf), obj, base) + return proxy - list_repr = _sequence_repr_maker('[', ']', list) - tuple_repr = _sequence_repr_maker('(', ')', tuple) - set_repr = _sequence_repr_maker('set([', '])', set) - frozenset_repr = _sequence_repr_maker('frozenset([', '])', frozenset) - if deque is not None: - deque_repr = _sequence_repr_maker('<span class="module">collections.' - '</span>deque([', '])', deque) + list_repr = _sequence_repr_maker("[", "]", list) + tuple_repr = _sequence_repr_maker("(", ")", tuple) + set_repr = _sequence_repr_maker("set([", "])", set) + frozenset_repr = _sequence_repr_maker("frozenset([", "])", frozenset) + deque_repr = _sequence_repr_maker( + '<span class="module">collections.' "</span>deque([", "])", deque + ) del _sequence_repr_maker def regex_repr(self, obj): pattern = repr(obj.pattern) if PY2: - pattern = pattern.decode('string-escape', 'ignore') + pattern = pattern.decode("string-escape", "ignore") else: - pattern = codecs.decode(pattern, 'unicode-escape', 'ignore') - if pattern[:1] == 'u': - pattern = 'ur' + pattern[1:] + pattern = codecs.decode(pattern, "unicode-escape", "ignore") + if pattern[:1] == "u": + pattern = "ur" + pattern[1:] else: - pattern = 'r' + pattern + pattern = "r" + pattern return u're.compile(<span class="string regex">%s</span>)' % pattern def string_repr(self, obj, limit=70): @@ -158,14 +159,18 @@ class DebugReprGenerator(object): # shorten the repr when the hidden part would be at least 3 chars if len(r) - limit > 2: - buf.extend(( - escape(r[:limit]), - '<span class="extended">', escape(r[limit:]), '</span>', - )) + buf.extend( + ( + escape(r[:limit]), + '<span class="extended">', + escape(r[limit:]), + "</span>", + ) + ) else: buf.append(escape(r)) - buf.append('</span>') + buf.append("</span>") out = u"".join(buf) # if the repr looks like a standard string, add subclass info if needed @@ -177,27 +182,29 @@ class DebugReprGenerator(object): def dict_repr(self, d, recursive, limit=5): if recursive: - return _add_subclass_info(u'{...}', d, dict) - buf = ['{'] + return _add_subclass_info(u"{...}", d, dict) + buf = ["{"] have_extended_section = False for idx, (key, value) in enumerate(iteritems(d)): if idx: - buf.append(', ') + buf.append(", ") if idx == limit - 1: buf.append('<span class="extended">') have_extended_section = True - buf.append('<span class="pair"><span class="key">%s</span>: ' - '<span class="value">%s</span></span>' % - (self.repr(key), self.repr(value))) + buf.append( + '<span class="pair"><span class="key">%s</span>: ' + '<span class="value">%s</span></span>' + % (self.repr(key), self.repr(value)) + ) if have_extended_section: - buf.append('</span>') - buf.append('}') - return _add_subclass_info(u''.join(buf), d, dict) + buf.append("</span>") + buf.append("}") + return _add_subclass_info(u"".join(buf), d, dict) def object_repr(self, obj): r = repr(obj) if PY2: - r = r.decode('utf-8', 'replace') + r = r.decode("utf-8", "replace") return u'<span class="object">%s</span>' % escape(r) def dispatch_repr(self, obj, recursive): @@ -225,13 +232,14 @@ class DebugReprGenerator(object): def fallback_repr(self): try: - info = ''.join(format_exception_only(*sys.exc_info()[:2])) + info = "".join(format_exception_only(*sys.exc_info()[:2])) except Exception: # pragma: no cover - info = '?' + info = "?" if PY2: - info = info.decode('utf-8', 'ignore') - return u'<span class="brokenrepr"><broken repr (%s)>' \ - u'</span>' % escape(info.strip()) + info = info.decode("utf-8", "ignore") + return u'<span class="brokenrepr"><broken repr (%s)>' u"</span>" % escape( + info.strip() + ) def repr(self, obj): recursive = False @@ -251,7 +259,7 @@ class DebugReprGenerator(object): def dump_object(self, obj): repr = items = None if isinstance(obj, dict): - title = 'Contents of' + title = "Contents of" items = [] for key, value in iteritems(obj): if not isinstance(key, string_types): @@ -266,23 +274,24 @@ class DebugReprGenerator(object): items.append((key, self.repr(getattr(obj, key)))) except Exception: pass - title = 'Details for' - title += ' ' + object.__repr__(obj)[1:-1] + title = "Details for" + title += " " + object.__repr__(obj)[1:-1] return self.render_object_dump(items, title, repr) def dump_locals(self, d): items = [(key, self.repr(value)) for key, value in d.items()] - return self.render_object_dump(items, 'Local variables in frame') + return self.render_object_dump(items, "Local variables in frame") def render_object_dump(self, items, title, repr=None): html_items = [] for key, value in items: - html_items.append('<tr><th>%s<td><pre class=repr>%s</pre>' % - (escape(key), value)) + html_items.append( + "<tr><th>%s<td><pre class=repr>%s</pre>" % (escape(key), value) + ) if not html_items: - html_items.append('<tr><td><em>Nothing</em>') + html_items.append("<tr><td><em>Nothing</em>") return OBJECT_DUMP_HTML % { - 'title': escape(title), - 'repr': '<pre class=repr>%s</pre>' % repr if repr else '', - 'items': '\n'.join(html_items) + "title": escape(title), + "repr": "<pre class=repr>%s</pre>" % repr if repr else "", + "items": "\n".join(html_items), } diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index 9422052f..f6af4e30 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -8,30 +8,33 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re - +import codecs +import inspect +import json import os +import re import sys -import json -import inspect import sysconfig import traceback -import codecs from tokenize import TokenError -from werkzeug.utils import cached_property, escape -from werkzeug.debug.console import Console -from werkzeug._compat import ( - range_type, PY2, text_type, string_types, - to_native, to_unicode, reraise, -) -from werkzeug.filesystem import get_filesystem_encoding +from .._compat import PY2 +from .._compat import range_type +from .._compat import reraise +from .._compat import string_types +from .._compat import text_type +from .._compat import to_native +from .._compat import to_unicode +from ..filesystem import get_filesystem_encoding +from ..utils import cached_property +from ..utils import escape +from .console import Console -_coding_re = re.compile(br'coding[:=]\s*([-\w.]+)') -_line_re = re.compile(br'^(.*?)$', re.MULTILINE) -_funcdef_re = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)') -UTF8_COOKIE = b'\xef\xbb\xbf' +_coding_re = re.compile(br"coding[:=]\s*([-\w.]+)") +_line_re = re.compile(br"^(.*?)$", re.MULTILINE) +_funcdef_re = re.compile(r"^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)") +UTF8_COOKIE = b"\xef\xbb\xbf" system_exceptions = (SystemExit, KeyboardInterrupt) try: @@ -40,7 +43,7 @@ except NameError: pass -HEADER = u'''\ +HEADER = u"""\ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/TR/html4/loose.dtd"> <html> @@ -65,8 +68,8 @@ HEADER = u'''\ </head> <body style="background-color: #fff"> <div class="debugger"> -''' -FOOTER = u'''\ +""" +FOOTER = u"""\ <div class="footer"> Brought to you by <strong class="arthur">DON'T PANIC</strong>, your friendly Werkzeug powered traceback interpreter. @@ -89,9 +92,11 @@ FOOTER = u'''\ </div> </body> </html> -''' +""" -PAGE_HTML = HEADER + u'''\ +PAGE_HTML = ( + HEADER + + u"""\ <h1>%(exception_type)s</h1> <div class="detail"> <p class="errormsg">%(exception)s</p> @@ -117,61 +122,69 @@ PAGE_HTML = HEADER + u'''\ execution (if the evalex feature is enabled), automatic pasting of the exceptions and much more.</span> </div> -''' + FOOTER + ''' +""" + + FOOTER + + """ <!-- %(plaintext_cs)s --> -''' +""" +) -CONSOLE_HTML = HEADER + u'''\ +CONSOLE_HTML = ( + HEADER + + u"""\ <h1>Interactive Console</h1> <div class="explanation"> In this console you can execute Python expressions in the context of the application. The initial namespace was created by the debugger automatically. </div> <div class="console"><div class="inner">The Console requires JavaScript.</div></div> -''' + FOOTER +""" + + FOOTER +) -SUMMARY_HTML = u'''\ +SUMMARY_HTML = u"""\ <div class="%(classes)s"> %(title)s <ul>%(frames)s</ul> %(description)s </div> -''' +""" -FRAME_HTML = u'''\ +FRAME_HTML = u"""\ <div class="frame" id="frame-%(id)d"> <h4>File <cite class="filename">"%(filename)s"</cite>, line <em class="line">%(lineno)s</em>, in <code class="function">%(function_name)s</code></h4> <div class="source %(library)s">%(lines)s</div> </div> -''' +""" -SOURCE_LINE_HTML = u'''\ +SOURCE_LINE_HTML = u"""\ <tr class="%(classes)s"> <td class=lineno>%(lineno)s</td> <td>%(code)s</td> </tr> -''' +""" def render_console_html(secret, evalex_trusted=True): return CONSOLE_HTML % { - 'evalex': 'true', - 'evalex_trusted': 'true' if evalex_trusted else 'false', - 'console': 'true', - 'title': 'Console', - 'secret': secret, - 'traceback_id': -1 + "evalex": "true", + "evalex_trusted": "true" if evalex_trusted else "false", + "console": "true", + "title": "Console", + "secret": secret, + "traceback_id": -1, } -def get_current_traceback(ignore_system_exceptions=False, - show_hidden_frames=False, skip=0): +def get_current_traceback( + ignore_system_exceptions=False, show_hidden_frames=False, skip=0 +): """Get the current exception info as `Traceback` object. Per default calling this method will reraise system exceptions such as generator exit, system exit or others. This behavior can be disabled by passing `False` @@ -192,7 +205,8 @@ def get_current_traceback(ignore_system_exceptions=False, class Line(object): """Helper for the source renderer.""" - __slots__ = ('lineno', 'code', 'in_frame', 'current') + + __slots__ = ("lineno", "code", "in_frame", "current") def __init__(self, lineno, code): self.lineno = lineno @@ -202,18 +216,18 @@ class Line(object): @property def classes(self): - rv = ['line'] + rv = ["line"] if self.in_frame: - rv.append('in-frame') + rv.append("in-frame") if self.current: - rv.append('current') + rv.append("current") return rv def render(self): return SOURCE_LINE_HTML % { - 'classes': u' '.join(self.classes), - 'lineno': self.lineno, - 'code': escape(self.code) + "classes": u" ".join(self.classes), + "lineno": self.lineno, + "code": escape(self.code), } @@ -227,7 +241,7 @@ class Traceback(object): exception_type = exc_type.__name__ if exc_type.__module__ not in {"builtins", "__builtin__", "exceptions"}: - exception_type = exc_type.__module__ + '.' + exception_type + exception_type = exc_type.__module__ + "." + exception_type self.exception_type = exception_type self.groups = [] @@ -264,38 +278,33 @@ class Traceback(object): """Log the ASCII traceback into a file object.""" if logfile is None: logfile = sys.stderr - tb = self.plaintext.rstrip() + u'\n' + tb = self.plaintext.rstrip() + u"\n" logfile.write(to_native(tb, "utf-8", "replace")) def paste(self): """Create a paste and return the paste id.""" - data = json.dumps({ - 'description': 'Werkzeug Internal Server Error', - 'public': False, - 'files': { - 'traceback.txt': { - 'content': self.plaintext - } + data = json.dumps( + { + "description": "Werkzeug Internal Server Error", + "public": False, + "files": {"traceback.txt": {"content": self.plaintext}}, } - }).encode('utf-8') + ).encode("utf-8") try: from urllib2 import urlopen except ImportError: from urllib.request import urlopen - rv = urlopen('https://api.github.com/gists', data=data) - resp = json.loads(rv.read().decode('utf-8')) + rv = urlopen("https://api.github.com/gists", data=data) + resp = json.loads(rv.read().decode("utf-8")) rv.close() - return { - 'url': resp['html_url'], - 'id': resp['id'] - } + return {"url": resp["html_url"], "id": resp["id"]} def render_summary(self, include_title=True): """Render the traceback for the interactive console.""" - title = '' - classes = ['traceback'] + title = "" + classes = ["traceback"] if not self.frames: - classes.append('noframe-traceback') + classes.append("noframe-traceback") frames = [] else: library_frames = sum(frame.is_library for frame in self.frames) @@ -304,38 +313,37 @@ class Traceback(object): if include_title: if self.is_syntax_error: - title = u'Syntax Error' + title = u"Syntax Error" else: - title = u'Traceback <em>(most recent call last)</em>:' + title = u"Traceback <em>(most recent call last)</em>:" if self.is_syntax_error: - description_wrapper = u'<pre class=syntaxerror>%s</pre>' + description_wrapper = u"<pre class=syntaxerror>%s</pre>" else: - description_wrapper = u'<blockquote>%s</blockquote>' + description_wrapper = u"<blockquote>%s</blockquote>" return SUMMARY_HTML % { - 'classes': u' '.join(classes), - 'title': u'<h3>%s</h3>' % title if title else u'', - 'frames': u'\n'.join(frames), - 'description': description_wrapper % escape(self.exception) + "classes": u" ".join(classes), + "title": u"<h3>%s</h3>" % title if title else u"", + "frames": u"\n".join(frames), + "description": description_wrapper % escape(self.exception), } - def render_full(self, evalex=False, secret=None, - evalex_trusted=True): + def render_full(self, evalex=False, secret=None, evalex_trusted=True): """Render the Full HTML page with the traceback info.""" exc = escape(self.exception) return PAGE_HTML % { - 'evalex': 'true' if evalex else 'false', - 'evalex_trusted': 'true' if evalex_trusted else 'false', - 'console': 'false', - 'title': exc, - 'exception': exc, - 'exception_type': escape(self.exception_type), - 'summary': self.render_summary(include_title=False), - 'plaintext': escape(self.plaintext), - 'plaintext_cs': re.sub('-{2,}', '-', self.plaintext), - 'traceback_id': self.id, - 'secret': secret + "evalex": "true" if evalex else "false", + "evalex_trusted": "true" if evalex_trusted else "false", + "console": "false", + "title": exc, + "exception": exc, + "exception_type": escape(self.exception_type), + "summary": self.render_summary(include_title=False), + "plaintext": escape(self.plaintext), + "plaintext_cs": re.sub("-{2,}", "-", self.plaintext), + "traceback_id": self.id, + "secret": secret, } @cached_property @@ -418,10 +426,13 @@ class Group(object): if self.info is not None: out.append(u'<li><div class="exc-divider">%s:</div>' % self.info) for frame in self.frames: - out.append(u"<li%s>%s" % ( - u' title="%s"' % escape(frame.info) if frame.info else u"", - frame.render(mark_lib=mark_lib) - )) + out.append( + u"<li%s>%s" + % ( + u' title="%s"' % escape(frame.info) if frame.info else u"", + frame.render(mark_lib=mark_lib), + ) + ) return u"\n".join(out) def render_text(self): @@ -436,7 +447,6 @@ class Group(object): class Frame(object): - """A single frame in a traceback.""" def __init__(self, exc_type, exc_value, tb): @@ -446,19 +456,19 @@ class Frame(object): self.globals = tb.tb_frame.f_globals fn = inspect.getsourcefile(tb) or inspect.getfile(tb) - if fn[-4:] in ('.pyo', '.pyc'): + if fn[-4:] in (".pyo", ".pyc"): fn = fn[:-1] # if it's a file on the file system resolve the real filename. if os.path.isfile(fn): fn = os.path.realpath(fn) self.filename = to_unicode(fn, get_filesystem_encoding()) - self.module = self.globals.get('__name__') - self.loader = self.globals.get('__loader__') + self.module = self.globals.get("__name__") + self.loader = self.globals.get("__loader__") self.code = tb.tb_frame.f_code # support for paste's traceback extensions - self.hide = self.locals.get('__traceback_hide__', False) - info = self.locals.get('__traceback_info__') + self.hide = self.locals.get("__traceback_hide__", False) + info = self.locals.get("__traceback_info__") if info is not None: info = to_unicode(info, "utf-8", "replace") self.info = info @@ -466,11 +476,11 @@ class Frame(object): def render(self, mark_lib=True): """Render a single frame in a traceback.""" return FRAME_HTML % { - 'id': self.id, - 'filename': escape(self.filename), - 'lineno': self.lineno, - 'function_name': escape(self.function_name), - 'lines': self.render_line_context(), + "id": self.id, + "filename": escape(self.filename), + "lineno": self.lineno, + "function_name": escape(self.function_name), + "lines": self.render_line_context(), "library": "library" if mark_lib and self.is_library else "", } @@ -485,7 +495,7 @@ class Frame(object): self.filename, self.lineno, self.function_name, - self.current_line.strip() + self.current_line.strip(), ) def render_line_context(self): @@ -497,34 +507,34 @@ class Frame(object): stripped_line = line.strip() prefix = len(line) - len(stripped_line) rv.append( - '<pre class="line %s"><span class="ws">%s</span>%s</pre>' % ( - cls, ' ' * prefix, escape(stripped_line) or ' ')) + '<pre class="line %s"><span class="ws">%s</span>%s</pre>' + % (cls, " " * prefix, escape(stripped_line) or " ") + ) for line in before: - render_line(line, 'before') - render_line(current, 'current') + render_line(line, "before") + render_line(current, "current") for line in after: - render_line(line, 'after') + render_line(line, "after") - return '\n'.join(rv) + return "\n".join(rv) def get_annotated_lines(self): """Helper function that returns lines with extra information.""" lines = [Line(idx + 1, x) for idx, x in enumerate(self.sourcelines)] # find function definition and mark lines - if hasattr(self.code, 'co_firstlineno'): + if hasattr(self.code, "co_firstlineno"): lineno = self.code.co_firstlineno - 1 while lineno > 0: if _funcdef_re.match(lines[lineno].code): break lineno -= 1 try: - offset = len(inspect.getblock([x.code + '\n' for x - in lines[lineno:]])) + offset = len(inspect.getblock([x.code + "\n" for x in lines[lineno:]])) except TokenError: offset = 0 - for line in lines[lineno:lineno + offset]: + for line in lines[lineno : lineno + offset]: line.in_frame = True # mark current line @@ -535,12 +545,12 @@ class Frame(object): return lines - def eval(self, code, mode='single'): + def eval(self, code, mode="single"): """Evaluate code in the context of the frame.""" if isinstance(code, string_types): if PY2 and isinstance(code, text_type): # noqa - code = UTF8_COOKIE + code.encode('utf-8') - code = compile(code, '<interactive>', mode) + code = UTF8_COOKIE + code.encode("utf-8") + code = compile(code, "<interactive>", mode) return eval(code, self.globals, self.locals) @cached_property @@ -550,9 +560,9 @@ class Frame(object): source = None if self.loader is not None: try: - if hasattr(self.loader, 'get_source'): + if hasattr(self.loader, "get_source"): source = self.loader.get_source(self.module) - elif hasattr(self.loader, 'get_source_by_code'): + elif hasattr(self.loader, "get_source_by_code"): source = self.loader.get_source_by_code(self.code) except Exception: # we munch the exception so that we don't cause troubles @@ -561,8 +571,7 @@ class Frame(object): if source is None: try: - f = open(to_native(self.filename, get_filesystem_encoding()), - mode='rb') + f = open(to_native(self.filename, get_filesystem_encoding()), mode="rb") except IOError: return [] try: @@ -576,7 +585,7 @@ class Frame(object): # yes. it should be ascii, but we don't want to reject too many # characters in the debugger if something breaks - charset = 'utf-8' + charset = "utf-8" if source.startswith(UTF8_COOKIE): source = source[3:] else: @@ -593,25 +602,21 @@ class Frame(object): try: codecs.lookup(charset) except LookupError: - charset = 'utf-8' + charset = "utf-8" - return source.decode(charset, 'replace').splitlines() + return source.decode(charset, "replace").splitlines() def get_context_lines(self, context=5): - before = self.sourcelines[self.lineno - context - 1:self.lineno - 1] - past = self.sourcelines[self.lineno:self.lineno + context] - return ( - before, - self.current_line, - past, - ) + before = self.sourcelines[self.lineno - context - 1 : self.lineno - 1] + past = self.sourcelines[self.lineno : self.lineno + context] + return (before, self.current_line, past) @property def current_line(self): try: return self.sourcelines[self.lineno - 1] except IndexError: - return u'' + return u"" @cached_property def console(self): diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index e2c2cb3e..0e6d1ce3 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -59,23 +59,23 @@ """ import sys +import werkzeug + # Because of bootstrapping reasons we need to manually patch ourselves # onto our parent module. -import werkzeug werkzeug.exceptions = sys.modules[__name__] -from werkzeug._internal import _get_environ -from werkzeug._compat import iteritems, integer_types, text_type, \ - implements_to_string - -from werkzeug.wrappers import Response +from ._compat import implements_to_string +from ._compat import integer_types +from ._compat import iteritems +from ._compat import text_type +from ._internal import _get_environ +from .wrappers import Response @implements_to_string class HTTPException(Exception): - - """ - Baseclass for all HTTP exceptions. This exception can be called as WSGI + """Baseclass for all HTTP exceptions. This exception can be called as WSGI application to render a default error page or you can catch the subclasses of it independently and render nicer error messages. """ @@ -102,6 +102,7 @@ class HTTPException(Exception): .. versionchanged:: 0.15 The description includes the wrapped exception message. """ + class newcls(cls, exception): def __init__(self, arg=None, *args, **kwargs): super(cls, self).__init__(*args, **kwargs) @@ -116,41 +117,43 @@ class HTTPException(Exception): if self.args: out += "<p><pre><code>{}: {}</code></pre></p>".format( - exception.__name__, - escape(exception.__str__(self)) + exception.__name__, escape(exception.__str__(self)) ) return out - newcls.__module__ = sys._getframe(1).f_globals.get('__name__') + newcls.__module__ = sys._getframe(1).f_globals.get("__name__") newcls.__name__ = name or cls.__name__ + exception.__name__ return newcls @property def name(self): """The status name.""" - return HTTP_STATUS_CODES.get(self.code, 'Unknown Error') + return HTTP_STATUS_CODES.get(self.code, "Unknown Error") def get_description(self, environ=None): """Get the description.""" - return u'<p>%s</p>' % escape(self.description) + return u"<p>%s</p>" % escape(self.description) def get_body(self, environ=None): """Get the HTML body.""" - return text_type(( - u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' - u'<title>%(code)s %(name)s</title>\n' - u'<h1>%(name)s</h1>\n' - u'%(description)s\n' - ) % { - 'code': self.code, - 'name': escape(self.name), - 'description': self.get_description(environ) - }) + return text_type( + ( + u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' + u"<title>%(code)s %(name)s</title>\n" + u"<h1>%(name)s</h1>\n" + u"%(description)s\n" + ) + % { + "code": self.code, + "name": escape(self.name), + "description": self.get_description(environ), + } + ) def get_headers(self, environ=None): """Get a list of headers.""" - return [('Content-Type', 'text/html')] + return [("Content-Type", "text/html")] def get_response(self, environ=None): """Get a response object. If one was passed to the exception @@ -179,30 +182,29 @@ class HTTPException(Exception): return response(environ, start_response) def __str__(self): - code = self.code if self.code is not None else '???' - return '%s %s: %s' % (code, self.name, self.description) + code = self.code if self.code is not None else "???" + return "%s %s: %s" % (code, self.name, self.description) def __repr__(self): - code = self.code if self.code is not None else '???' + code = self.code if self.code is not None else "???" return "<%s '%s: %s'>" % (self.__class__.__name__, code, self.name) class BadRequest(HTTPException): - """*400* `Bad Request` Raise if the browser sends something to the application the application or server cannot handle. """ + code = 400 description = ( - 'The browser (or proxy) sent a request that this server could ' - 'not understand.' + "The browser (or proxy) sent a request that this server could " + "not understand." ) class ClientDisconnected(BadRequest): - """Internal exception that is raised if Werkzeug detects a disconnected client. Since the client is already gone at that point attempting to send the error message to the client might not work and might ultimately @@ -218,7 +220,6 @@ class ClientDisconnected(BadRequest): class SecurityError(BadRequest): - """Raised if something triggers a security error. This is otherwise exactly like a bad request error. @@ -227,7 +228,6 @@ class SecurityError(BadRequest): class BadHost(BadRequest): - """Raised if the submitted host is badly formatted. .. versionadded:: 0.11.2 @@ -235,7 +235,6 @@ class BadHost(BadRequest): class Unauthorized(HTTPException): - """*401* ``Unauthorized`` Raise if the user is not authorized to access a resource. @@ -252,12 +251,13 @@ class Unauthorized(HTTPException): :param description: Override the default message used for the body of the response. """ + code = 401 description = ( - 'The server could not verify that you are authorized to access' - ' the URL requested. You either supplied the wrong credentials' + "The server could not verify that you are authorized to access" + " the URL requested. You either supplied the wrong credentials" " (e.g. a bad password), or your browser doesn't understand" - ' how to supply the credentials required.' + " how to supply the credentials required." ) def __init__(self, www_authenticate=None, description=None): @@ -269,43 +269,41 @@ class Unauthorized(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.www_authenticate: - headers.append(( - 'WWW-Authenticate', - ', '.join([str(x) for x in self.www_authenticate]) - )) + headers.append( + ("WWW-Authenticate", ", ".join([str(x) for x in self.www_authenticate])) + ) return headers class Forbidden(HTTPException): - """*403* `Forbidden` Raise if the user doesn't have the permission for the requested resource but was authenticated. """ + code = 403 description = ( - 'You don\'t have the permission to access the requested resource. ' - 'It is either read-protected or not readable by the server.' + "You don't have the permission to access the requested" + " resource. It is either read-protected or not readable by the" + " server." ) class NotFound(HTTPException): - """*404* `Not Found` Raise if a resource does not exist and never existed. """ + code = 404 description = ( - 'The requested URL was not found on the server. ' - 'If you entered the URL manually please check your spelling and ' - 'try again.' + "The requested URL was not found on the server. If you entered" + " the URL manually please check your spelling and try again." ) class MethodNotAllowed(HTTPException): - """*405* `Method Not Allowed` Raise if the server used a method the resource does not handle. For @@ -315,8 +313,9 @@ class MethodNotAllowed(HTTPException): Strictly speaking the response would be invalid if you don't provide valid methods in the header which you can do with that list. """ + code = 405 - description = 'The method is not allowed for the requested URL.' + description = "The method is not allowed for the requested URL." def __init__(self, valid_methods=None, description=None): """Takes an optional list of valid http methods @@ -327,42 +326,41 @@ class MethodNotAllowed(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.valid_methods: - headers.append(('Allow', ', '.join(self.valid_methods))) + headers.append(("Allow", ", ".join(self.valid_methods))) return headers class NotAcceptable(HTTPException): - """*406* `Not Acceptable` Raise if the server can't return any content conforming to the `Accept` headers of the client. """ + code = 406 description = ( - 'The resource identified by the request is only capable of ' - 'generating response entities which have content characteristics ' - 'not acceptable according to the accept headers sent in the ' - 'request.' + "The resource identified by the request is only capable of" + " generating response entities which have content" + " characteristics not acceptable according to the accept" + " headers sent in the request." ) class RequestTimeout(HTTPException): - """*408* `Request Timeout` Raise to signalize a timeout. """ + code = 408 description = ( - 'The server closed the network connection because the browser ' - 'didn\'t finish the request within the specified time.' + "The server closed the network connection because the browser" + " didn't finish the request within the specified time." ) class Conflict(HTTPException): - """*409* `Conflict` Raise to signal that a request cannot be completed because it conflicts @@ -370,107 +368,103 @@ class Conflict(HTTPException): .. versionadded:: 0.7 """ + code = 409 description = ( - 'A conflict happened while processing the request. The resource ' - 'might have been modified while the request was being processed.' + "A conflict happened while processing the request. The" + " resource might have been modified while the request was being" + " processed." ) class Gone(HTTPException): - """*410* `Gone` Raise if a resource existed previously and went away without new location. """ + code = 410 description = ( - 'The requested URL is no longer available on this server and there ' - 'is no forwarding address. If you followed a link from a foreign ' - 'page, please contact the author of this page.' + "The requested URL is no longer available on this server and" + " there is no forwarding address. If you followed a link from a" + " foreign page, please contact the author of this page." ) class LengthRequired(HTTPException): - """*411* `Length Required` Raise if the browser submitted data but no ``Content-Length`` header which is required for the kind of processing the server does. """ + code = 411 description = ( - 'A request with this method requires a valid <code>Content-' - 'Length</code> header.' + "A request with this method requires a valid <code>Content-" + "Length</code> header." ) class PreconditionFailed(HTTPException): - """*412* `Precondition Failed` Status code used in combination with ``If-Match``, ``If-None-Match``, or ``If-Unmodified-Since``. """ + code = 412 description = ( - 'The precondition on the request for the URL failed positive ' - 'evaluation.' + "The precondition on the request for the URL failed positive evaluation." ) class RequestEntityTooLarge(HTTPException): - """*413* `Request Entity Too Large` The status code one should return if the data submitted exceeded a given limit. """ + code = 413 - description = ( - 'The data value transmitted exceeds the capacity limit.' - ) + description = "The data value transmitted exceeds the capacity limit." class RequestURITooLarge(HTTPException): - """*414* `Request URI Too Large` Like *413* but for too long URLs. """ + code = 414 description = ( - 'The length of the requested URL exceeds the capacity limit ' - 'for this server. The request cannot be processed.' + "The length of the requested URL exceeds the capacity limit for" + " this server. The request cannot be processed." ) class UnsupportedMediaType(HTTPException): - """*415* `Unsupported Media Type` The status code returned if the server is unable to handle the media type the client transmitted. """ + code = 415 description = ( - 'The server does not support the media type transmitted in ' - 'the request.' + "The server does not support the media type transmitted in the request." ) class RequestedRangeNotSatisfiable(HTTPException): - """*416* `Requested Range Not Satisfiable` The client asked for an invalid part of the file. .. versionadded:: 0.7 """ + code = 416 - description = ( - 'The server cannot provide the requested range.' - ) + description = "The server cannot provide the requested range." def __init__(self, length=None, units="bytes", description=None): """Takes an optional `Content-Range` header value based on ``length`` @@ -483,27 +477,23 @@ class RequestedRangeNotSatisfiable(HTTPException): def get_headers(self, environ=None): headers = HTTPException.get_headers(self, environ) if self.length is not None: - headers.append( - ('Content-Range', '%s */%d' % (self.units, self.length))) + headers.append(("Content-Range", "%s */%d" % (self.units, self.length))) return headers class ExpectationFailed(HTTPException): - """*417* `Expectation Failed` The server cannot meet the requirements of the Expect request-header. .. versionadded:: 0.7 """ + code = 417 - description = ( - 'The server could not meet the requirements of the Expect header' - ) + description = "The server could not meet the requirements of the Expect header" class ImATeapot(HTTPException): - """*418* `I'm a teapot` The server should return this if it is a teapot and someone attempted @@ -511,54 +501,51 @@ class ImATeapot(HTTPException): .. versionadded:: 0.7 """ + code = 418 - description = ( - 'This server is a teapot, not a coffee machine' - ) + description = "This server is a teapot, not a coffee machine" class UnprocessableEntity(HTTPException): - """*422* `Unprocessable Entity` Used if the request is well formed, but the instructions are otherwise incorrect. """ + code = 422 description = ( - 'The request was well-formed but was unable to be followed ' - 'due to semantic errors.' + "The request was well-formed but was unable to be followed due" + " to semantic errors." ) class Locked(HTTPException): - """*423* `Locked` Used if the resource that is being accessed is locked. """ + code = 423 - description = ( - 'The resource that is being accessed is locked.' - ) + description = "The resource that is being accessed is locked." class FailedDependency(HTTPException): - """*424* `Failed Dependency` Used if the method could not be performed on the resource because the requested action depended on another action and that action failed. """ + code = 424 description = ( - 'The method could not be performed on the resource because the requested action ' - 'depended on another action and that action failed.' + "The method could not be performed on the resource because the" + " requested action depended on another action and that action" + " failed." ) class PreconditionRequired(HTTPException): - """*428* `Precondition Required` The server requires this request to be conditional, typically to prevent @@ -569,15 +556,15 @@ class PreconditionRequired(HTTPException): server ensures that each client has at least seen the previous revision of the resource. """ + code = 428 description = ( - 'This request is required to be conditional; try using "If-Match" ' - 'or "If-Unmodified-Since".' + "This request is required to be conditional; try using" + ' "If-Match" or "If-Unmodified-Since".' ) class TooManyRequests(HTTPException): - """*429* `Too Many Requests` The server is limiting the rate at which this user receives responses, and @@ -586,129 +573,117 @@ class TooManyRequests(HTTPException): "Retry-After" header to indicate how long the user should wait before retrying. """ + code = 429 - description = ( - 'This user has exceeded an allotted request count. Try again later.' - ) + description = "This user has exceeded an allotted request count. Try again later." class RequestHeaderFieldsTooLarge(HTTPException): - """*431* `Request Header Fields Too Large` The server refuses to process the request because the header fields are too large. One or more individual fields may be too large, or the set of all headers is too large. """ + code = 431 - description = ( - 'One or more header fields exceeds the maximum size.' - ) + description = "One or more header fields exceeds the maximum size." class UnavailableForLegalReasons(HTTPException): - """*451* `Unavailable For Legal Reasons` This status code indicates that the server is denying access to the resource as a consequence of a legal demand. """ + code = 451 - description = ( - 'Unavailable for legal reasons.' - ) + description = "Unavailable for legal reasons." class InternalServerError(HTTPException): - """*500* `Internal Server Error` Raise if an internal server error occurred. This is a good fallback if an unknown error occurred in the dispatcher. """ + code = 500 description = ( - 'The server encountered an internal error and was unable to ' - 'complete your request. Either the server is overloaded or there ' - 'is an error in the application.' + "The server encountered an internal error and was unable to" + " complete your request. Either the server is overloaded or" + " there is an error in the application." ) class NotImplemented(HTTPException): - """*501* `Not Implemented` Raise if the application does not support the action requested by the browser. """ + code = 501 - description = ( - 'The server does not support the action requested by the ' - 'browser.' - ) + description = "The server does not support the action requested by the browser." class BadGateway(HTTPException): - """*502* `Bad Gateway` If you do proxying in your application you should return this status code if you received an invalid response from the upstream server it accessed in attempting to fulfill the request. """ + code = 502 description = ( - 'The proxy server received an invalid response from an upstream ' - 'server.' + "The proxy server received an invalid response from an upstream server." ) class ServiceUnavailable(HTTPException): - """*503* `Service Unavailable` Status code you should return if a service is temporarily unavailable. """ + code = 503 description = ( - 'The server is temporarily unable to service your request due to ' - 'maintenance downtime or capacity problems. Please try again ' - 'later.' + "The server is temporarily unable to service your request due" + " to maintenance downtime or capacity problems. Please try" + " again later." ) class GatewayTimeout(HTTPException): - """*504* `Gateway Timeout` Status code you should return if a connection to an upstream server times out. """ + code = 504 - description = ( - 'The connection to an upstream server timed out.' - ) + description = "The connection to an upstream server timed out." class HTTPVersionNotSupported(HTTPException): - """*505* `HTTP Version Not Supported` The server does not support the HTTP protocol version used in the request. """ + code = 505 description = ( - 'The server does not support the HTTP protocol version used in the ' - 'request.' + "The server does not support the HTTP protocol version used in the request." ) default_exceptions = {} -__all__ = ['HTTPException'] +__all__ = ["HTTPException"] def _find_exceptions(): - for name, obj in iteritems(globals()): + for _name, obj in iteritems(globals()): try: is_http_exception = issubclass(obj, HTTPException) except TypeError: @@ -720,14 +695,14 @@ def _find_exceptions(): if old_obj is not None and issubclass(obj, old_obj): continue default_exceptions[obj.code] = obj + + _find_exceptions() del _find_exceptions class Aborter(object): - - """ - When passed a dict of code -> exception items it can be used as + """When passed a dict of code -> exception items it can be used as callable that raises exceptions. If the first argument to the callable is an integer it will be looked up in the mapping, if it's a WSGI application it will be raised in a proxy exception. @@ -746,13 +721,12 @@ class Aborter(object): if not args and not kwargs and not isinstance(code, integer_types): raise HTTPException(response=code) if code not in self.mapping: - raise LookupError('no exception for %r' % code) + raise LookupError("no exception for %r" % code) raise self.mapping[code](*args, **kwargs) def abort(status, *args, **kwargs): - ''' - Raises an :py:exc:`HTTPException` for the given status code or WSGI + """Raises an :py:exc:`HTTPException` for the given status code or WSGI application:: abort(404) # 404 Not Found @@ -766,9 +740,10 @@ def abort(status, *args, **kwargs): abort(404) abort(Response('Hello World')) - ''' + """ return _aborter(status, *args, **kwargs) + _aborter = Aborter() @@ -777,5 +752,5 @@ _aborter = Aborter() BadRequestKeyError = BadRequest.wrap(KeyError) # imported here because of circular dependencies of werkzeug.utils -from werkzeug.utils import escape -from werkzeug.http import HTTP_STATUS_CODES +from .http import HTTP_STATUS_CODES +from .utils import escape diff --git a/src/werkzeug/filesystem.py b/src/werkzeug/filesystem.py index 927f37ba..d016caea 100644 --- a/src/werkzeug/filesystem.py +++ b/src/werkzeug/filesystem.py @@ -8,41 +8,39 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - import codecs import sys import warnings # We do not trust traditional unixes. -has_likely_buggy_unicode_filesystem = \ - sys.platform.startswith('linux') or 'bsd' in sys.platform +has_likely_buggy_unicode_filesystem = ( + sys.platform.startswith("linux") or "bsd" in sys.platform +) def _is_ascii_encoding(encoding): - """ - Given an encoding this figures out if the encoding is actually ASCII (which + """Given an encoding this figures out if the encoding is actually ASCII (which is something we don't actually want in most cases). This is necessary because ASCII comes under many names such as ANSI_X3.4-1968. """ if encoding is None: return False try: - return codecs.lookup(encoding).name == 'ascii' + return codecs.lookup(encoding).name == "ascii" except LookupError: return False class BrokenFilesystemWarning(RuntimeWarning, UnicodeWarning): - '''The warning used by Werkzeug to signal a broken filesystem. Will only be - used once per runtime.''' + """The warning used by Werkzeug to signal a broken filesystem. Will only be + used once per runtime.""" _warned_about_filesystem_encoding = False def get_filesystem_encoding(): - """ - Returns the filesystem encoding that should be used. Note that this is + """Returns the filesystem encoding that should be used. Note that this is different from the Python understanding of the filesystem encoding which might be deeply flawed. Do not use this value against Python's unicode APIs because it might be different. See :ref:`filesystem-encoding` for the exact @@ -54,13 +52,13 @@ def get_filesystem_encoding(): """ global _warned_about_filesystem_encoding rv = sys.getfilesystemencoding() - if has_likely_buggy_unicode_filesystem and not rv \ - or _is_ascii_encoding(rv): + if has_likely_buggy_unicode_filesystem and not rv or _is_ascii_encoding(rv): if not _warned_about_filesystem_encoding: warnings.warn( - 'Detected a misconfigured UNIX filesystem: Will use UTF-8 as ' - 'filesystem encoding instead of {0!r}'.format(rv), - BrokenFilesystemWarning) + "Detected a misconfigured UNIX filesystem: Will use" + " UTF-8 as filesystem encoding instead of {0!r}".format(rv), + BrokenFilesystemWarning, + ) _warned_about_filesystem_encoding = True - return 'utf-8' + return "utf-8" return rv diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 370c0751..e626935d 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -9,8 +9,24 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import re import codecs +import re +from functools import update_wrapper +from itertools import chain +from itertools import repeat +from itertools import tee + +from ._compat import BytesIO +from ._compat import text_type +from ._compat import to_native +from .datastructures import FileStorage +from .datastructures import Headers +from .datastructures import MultiDict +from .http import parse_options_header +from .urls import url_decode_stream +from .wsgi import get_content_length +from .wsgi import get_input_stream +from .wsgi import make_line_iter # there are some platforms where SpooledTemporaryFile is not available. # In that case we need to provide a fallback. @@ -18,45 +34,43 @@ try: from tempfile import SpooledTemporaryFile except ImportError: from tempfile import TemporaryFile - SpooledTemporaryFile = None - -from itertools import chain, repeat, tee -from functools import update_wrapper -from werkzeug._compat import to_native, text_type, BytesIO -from werkzeug.urls import url_decode_stream -from werkzeug.wsgi import make_line_iter, \ - get_input_stream, get_content_length -from werkzeug.datastructures import Headers, FileStorage, MultiDict -from werkzeug.http import parse_options_header + SpooledTemporaryFile = None #: an iterator that yields empty strings -_empty_string_iter = repeat('') +_empty_string_iter = repeat("") #: a regular expression for multipart boundaries -_multipart_boundary_re = re.compile('^[ -~]{0,200}[!-~]$') +_multipart_boundary_re = re.compile("^[ -~]{0,200}[!-~]$") #: supported http encodings that are also available in python we support #: for multipart messages. -_supported_multipart_encodings = frozenset(['base64', 'quoted-printable']) +_supported_multipart_encodings = frozenset(["base64", "quoted-printable"]) -def default_stream_factory(total_content_length, filename, content_type, - content_length=None): +def default_stream_factory( + total_content_length, filename, content_type, content_length=None +): """The stream factory that is used per default.""" max_size = 1024 * 500 if SpooledTemporaryFile is not None: - return SpooledTemporaryFile(max_size=max_size, mode='wb+') + return SpooledTemporaryFile(max_size=max_size, mode="wb+") if total_content_length is None or total_content_length > max_size: - return TemporaryFile('wb+') + return TemporaryFile("wb+") return BytesIO() -def parse_form_data(environ, stream_factory=None, charset='utf-8', - errors='replace', max_form_memory_size=None, - max_content_length=None, cls=None, - silent=True): +def parse_form_data( + environ, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + max_content_length=None, + cls=None, + silent=True, +): """Parse the form data in the environ and return it as tuple in the form ``(stream, form, files)``. You should only call this method if the transport method is `POST`, `PUT`, or `PATCH`. @@ -97,9 +111,15 @@ def parse_form_data(environ, stream_factory=None, charset='utf-8', :param silent: If set to False parsing errors will not be caught. :return: A tuple in the form ``(stream, form, files)``. """ - return FormDataParser(stream_factory, charset, errors, - max_form_memory_size, max_content_length, - cls, silent).parse_from_environ(environ) + return FormDataParser( + stream_factory, + charset, + errors, + max_form_memory_size, + max_content_length, + cls, + silent, + ).parse_from_environ(environ) def exhaust_stream(f): @@ -109,7 +129,7 @@ def exhaust_stream(f): try: return f(self, stream, *args, **kwargs) finally: - exhaust = getattr(stream, 'exhaust', None) + exhaust = getattr(stream, "exhaust", None) if exhaust is not None: exhaust() else: @@ -117,11 +137,11 @@ def exhaust_stream(f): chunk = stream.read(1024 * 64) if not chunk: break + return update_wrapper(wrapper, f) class FormDataParser(object): - """This class implements parsing of form data for Werkzeug. By itself it can parse multipart and url encoded form data. It can be subclassed and extended but for most mimetypes it is a better idea to use the @@ -149,10 +169,16 @@ class FormDataParser(object): :param silent: If set to False parsing errors will not be caught. """ - def __init__(self, stream_factory=None, charset='utf-8', - errors='replace', max_form_memory_size=None, - max_content_length=None, cls=None, - silent=True): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + max_content_length=None, + cls=None, + silent=True, + ): if stream_factory is None: stream_factory = default_stream_factory self.stream_factory = stream_factory @@ -174,11 +200,10 @@ class FormDataParser(object): :param environ: the WSGI environment to be used for parsing. :return: A tuple in the form ``(stream, form, files)``. """ - content_type = environ.get('CONTENT_TYPE', '') + content_type = environ.get("CONTENT_TYPE", "") content_length = get_content_length(environ) mimetype, options = parse_options_header(content_type) - return self.parse(get_input_stream(environ), mimetype, - content_length, options) + return self.parse(get_input_stream(environ), mimetype, content_length, options) def parse(self, stream, mimetype, content_length, options=None): """Parses the information from the given stream, mimetype, @@ -191,9 +216,11 @@ class FormDataParser(object): the multipart boundary for instance) :return: A tuple in the form ``(stream, form, files)``. """ - if self.max_content_length is not None and \ - content_length is not None and \ - content_length > self.max_content_length: + if ( + self.max_content_length is not None + and content_length is not None + and content_length > self.max_content_length + ): raise exceptions.RequestEntityTooLarge() if options is None: options = {} @@ -201,8 +228,7 @@ class FormDataParser(object): parse_func = self.get_parse_func(mimetype, options) if parse_func is not None: try: - return parse_func(self, stream, mimetype, - content_length, options) + return parse_func(self, stream, mimetype, content_length, options) except ValueError: if not self.silent: raise @@ -211,32 +237,37 @@ class FormDataParser(object): @exhaust_stream def _parse_multipart(self, stream, mimetype, content_length, options): - parser = MultiPartParser(self.stream_factory, self.charset, self.errors, - max_form_memory_size=self.max_form_memory_size, - cls=self.cls) - boundary = options.get('boundary') + parser = MultiPartParser( + self.stream_factory, + self.charset, + self.errors, + max_form_memory_size=self.max_form_memory_size, + cls=self.cls, + ) + boundary = options.get("boundary") if boundary is None: - raise ValueError('Missing boundary') + raise ValueError("Missing boundary") if isinstance(boundary, text_type): - boundary = boundary.encode('ascii') + boundary = boundary.encode("ascii") form, files = parser.parse(stream, boundary, content_length) return stream, form, files @exhaust_stream def _parse_urlencoded(self, stream, mimetype, content_length, options): - if self.max_form_memory_size is not None and \ - content_length is not None and \ - content_length > self.max_form_memory_size: + if ( + self.max_form_memory_size is not None + and content_length is not None + and content_length > self.max_form_memory_size + ): raise exceptions.RequestEntityTooLarge() - form = url_decode_stream(stream, self.charset, - errors=self.errors, cls=self.cls) + form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls) return stream, form, self.cls() #: mapping of mimetypes to parsing functions parse_functions = { - 'multipart/form-data': _parse_multipart, - 'application/x-www-form-urlencoded': _parse_urlencoded, - 'application/x-url-encoded': _parse_urlencoded + "multipart/form-data": _parse_multipart, + "application/x-www-form-urlencoded": _parse_urlencoded, + "application/x-url-encoded": _parse_urlencoded, } @@ -249,9 +280,9 @@ def _line_parse(line): """Removes line ending characters and returns a tuple (`stripped_line`, `is_terminated`). """ - if line[-2:] in ['\r\n', b'\r\n']: + if line[-2:] in ["\r\n", b"\r\n"]: return line[:-2], True - elif line[-1:] in ['\r', '\n', b'\r', b'\n']: + elif line[-1:] in ["\r", "\n", b"\r", b"\n"]: return line[:-1], True return line, False @@ -270,14 +301,14 @@ def parse_multipart_headers(iterable): line = to_native(line) line, line_terminated = _line_parse(line) if not line_terminated: - raise ValueError('unexpected end of line in multipart header') + raise ValueError("unexpected end of line in multipart header") if not line: break - elif line[0] in ' \t' and result: + elif line[0] in " \t" and result: key, value = result[-1] - result[-1] = (key, value + '\n ' + line[1:]) + result[-1] = (key, value + "\n " + line[1:]) else: - parts = line.split(':', 1) + parts = line.split(":", 1) if len(parts) == 2: result.append((parts[0].strip(), parts[1].strip())) @@ -286,28 +317,36 @@ def parse_multipart_headers(iterable): return Headers(result) -_begin_form = 'begin_form' -_begin_file = 'begin_file' -_cont = 'cont' -_end = 'end' +_begin_form = "begin_form" +_begin_file = "begin_file" +_cont = "cont" +_end = "end" class MultiPartParser(object): - - def __init__(self, stream_factory=None, charset='utf-8', errors='replace', - max_form_memory_size=None, cls=None, buffer_size=64 * 1024): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + cls=None, + buffer_size=64 * 1024, + ): self.charset = charset self.errors = errors self.max_form_memory_size = max_form_memory_size - self.stream_factory = default_stream_factory if stream_factory is None else stream_factory + self.stream_factory = ( + default_stream_factory if stream_factory is None else stream_factory + ) self.cls = MultiDict if cls is None else cls # make sure the buffer size is divisible by four so that we can base64 # decode chunk by chunk - assert buffer_size % 4 == 0, 'buffer size has to be divisible by 4' + assert buffer_size % 4 == 0, "buffer size has to be divisible by 4" # also the buffer size has to be at least 1024 bytes long or long headers # will freak out the system - assert buffer_size >= 1024, 'buffer size has to be at least 1KB' + assert buffer_size >= 1024, "buffer size has to be at least 1KB" self.buffer_size = buffer_size @@ -316,8 +355,8 @@ class MultiPartParser(object): uploaded. This function strips the full path if it thinks the filename is Windows-like absolute. """ - if filename[1:3] == ':\\' or filename[:2] == '\\\\': - return filename.split('\\')[-1] + if filename[1:3] == ":\\" or filename[:2] == "\\\\": + return filename.split("\\")[-1] return filename def _find_terminator(self, iterator): @@ -331,36 +370,39 @@ class MultiPartParser(object): line = line.strip() if line: return line - return b'' + return b"" def fail(self, message): raise ValueError(message) def get_part_encoding(self, headers): - transfer_encoding = headers.get('content-transfer-encoding') - if transfer_encoding is not None and \ - transfer_encoding in _supported_multipart_encodings: + transfer_encoding = headers.get("content-transfer-encoding") + if ( + transfer_encoding is not None + and transfer_encoding in _supported_multipart_encodings + ): return transfer_encoding def get_part_charset(self, headers): # Figure out input charset for current part - content_type = headers.get('content-type') + content_type = headers.get("content-type") if content_type: mimetype, ct_params = parse_options_header(content_type) - return ct_params.get('charset', self.charset) + return ct_params.get("charset", self.charset) return self.charset def start_file_streaming(self, filename, headers, total_content_length): if isinstance(filename, bytes): filename = filename.decode(self.charset, self.errors) filename = self._fix_ie_filename(filename) - content_type = headers.get('content-type') + content_type = headers.get("content-type") try: - content_length = int(headers['content-length']) + content_length = int(headers["content-length"]) except (KeyError, ValueError): content_length = 0 - container = self.stream_factory(total_content_length, content_type, - filename, content_length) + container = self.stream_factory( + total_content_length, content_type, filename, content_length + ) return filename, container def in_memory_threshold_reached(self, bytes): @@ -368,15 +410,15 @@ class MultiPartParser(object): def validate_boundary(self, boundary): if not boundary: - self.fail('Missing boundary') + self.fail("Missing boundary") if not is_valid_multipart_boundary(boundary): - self.fail('Invalid boundary: %s' % boundary) + self.fail("Invalid boundary: %s" % boundary) if len(boundary) > self.buffer_size: # pragma: no cover # this should never happen because we check for a minimum size # of 1024 and boundaries may not be longer than 200. The only # situation when this happens is for non debug builds where # the assert is skipped. - self.fail('Boundary longer than buffer size') + self.fail("Boundary longer than buffer size") def parse_lines(self, file, boundary, content_length, cap_at_buffer=True): """Generate parts of @@ -389,31 +431,36 @@ class MultiPartParser(object): parts = ( begin_form cont* end | begin_file cont* end )* """ - next_part = b'--' + boundary - last_part = next_part + b'--' - - iterator = chain(make_line_iter(file, limit=content_length, - buffer_size=self.buffer_size, - cap_at_buffer=cap_at_buffer), - _empty_string_iter) + next_part = b"--" + boundary + last_part = next_part + b"--" + + iterator = chain( + make_line_iter( + file, + limit=content_length, + buffer_size=self.buffer_size, + cap_at_buffer=cap_at_buffer, + ), + _empty_string_iter, + ) terminator = self._find_terminator(iterator) if terminator == last_part: return elif terminator != next_part: - self.fail('Expected boundary at start of multipart data') + self.fail("Expected boundary at start of multipart data") while terminator != last_part: headers = parse_multipart_headers(iterator) - disposition = headers.get('content-disposition') + disposition = headers.get("content-disposition") if disposition is None: - self.fail('Missing Content-Disposition header') + self.fail("Missing Content-Disposition header") disposition, extra = parse_options_header(disposition) transfer_encoding = self.get_part_encoding(headers) - name = extra.get('name') - filename = extra.get('filename') + name = extra.get("name") + filename = extra.get("filename") # if no content type is given we stream into memory. A list is # used as a temporary container. @@ -425,29 +472,29 @@ class MultiPartParser(object): else: yield _begin_file, (headers, name, filename) - buf = b'' + buf = b"" for line in iterator: if not line: - self.fail('unexpected end of stream') + self.fail("unexpected end of stream") - if line[:2] == b'--': + if line[:2] == b"--": terminator = line.rstrip() if terminator in (next_part, last_part): break if transfer_encoding is not None: - if transfer_encoding == 'base64': - transfer_encoding = 'base64_codec' + if transfer_encoding == "base64": + transfer_encoding = "base64_codec" try: line = codecs.decode(line, transfer_encoding) except Exception: - self.fail('could not decode transfer encoded chunk') + self.fail("could not decode transfer encoded chunk") # we have something in the buffer from the last iteration. # this is usually a newline delimiter. if buf: yield _cont, buf - buf = b'' + buf = b"" # If the line ends with windows CRLF we write everything except # the last two bytes. In all other cases however we write @@ -458,8 +505,8 @@ class MultiPartParser(object): # truncate the stream. However we do have to make sure that # if something else than a newline is in there we write it # out. - if line[-2:] == b'\r\n': - buf = b'\r\n' + if line[-2:] == b"\r\n": + buf = b"\r\n" cutoff = -2 else: buf = line[-1:] @@ -467,12 +514,12 @@ class MultiPartParser(object): yield _cont, line[:cutoff] else: # pragma: no cover - raise ValueError('unexpected end of part') + raise ValueError("unexpected end of part") # if we have a leftover in the buffer that is not a newline # character we have to flush it, otherwise we will chop of # certain values. - if buf not in (b'', b'\r', b'\n', b'\r\n'): + if buf not in (b"", b"\r", b"\n", b"\r\n"): yield _cont, buf yield _end, None @@ -489,7 +536,8 @@ class MultiPartParser(object): is_file = True guard_memory = False filename, container = self.start_file_streaming( - filename, headers, content_length) + filename, headers, content_length + ) _write = container.write elif ellt == _begin_form: @@ -512,21 +560,24 @@ class MultiPartParser(object): elif ellt == _end: if is_file: container.seek(0) - yield ('file', - (name, FileStorage(container, filename, name, - headers=headers))) + yield ( + "file", + (name, FileStorage(container, filename, name, headers=headers)), + ) else: part_charset = self.get_part_charset(headers) - yield ('form', - (name, b''.join(container).decode( - part_charset, self.errors))) + yield ( + "form", + (name, b"".join(container).decode(part_charset, self.errors)), + ) def parse(self, file, boundary, content_length): formstream, filestream = tee( - self.parse_parts(file, boundary, content_length), 2) - form = (p[1] for p in formstream if p[0] == 'form') - files = (p[1] for p in filestream if p[0] == 'file') + self.parse_parts(file, boundary, content_length), 2 + ) + form = (p[1] for p in formstream if p[0] == "form") + files = (p[1] for p in filestream if p[0] == "file") return self.cls(form), self.cls(files) -from werkzeug import exceptions +from . import exceptions diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 9e18ea97..af320075 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -16,54 +16,68 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import base64 import re import warnings -from time import time, gmtime +from datetime import datetime +from datetime import timedelta +from hashlib import md5 +from time import gmtime +from time import time + +from ._compat import integer_types +from ._compat import iteritems +from ._compat import PY2 +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._internal import _cookie_parse_impl +from ._internal import _cookie_quote +from ._internal import _make_cookie_domain + try: from email.utils import parsedate_tz -except ImportError: # pragma: no cover +except ImportError: from email.Utils import parsedate_tz + try: from urllib.request import parse_http_list as _parse_list_header from urllib.parse import unquote_to_bytes as _unquote -except ImportError: # pragma: no cover - from urllib2 import parse_http_list as _parse_list_header, \ - unquote as _unquote -from datetime import datetime, timedelta -from hashlib import md5 -import base64 +except ImportError: + from urllib2 import parse_http_list as _parse_list_header + from urllib2 import unquote as _unquote -from werkzeug._internal import _cookie_quote, _make_cookie_domain, \ - _cookie_parse_impl -from werkzeug._compat import to_unicode, iteritems, text_type, \ - string_types, try_coerce_native, to_bytes, PY2, \ - integer_types - - -_cookie_charset = 'latin1' -_basic_auth_charset = 'utf-8' +_cookie_charset = "latin1" +_basic_auth_charset = "utf-8" # for explanation of "media-range", etc. see Sections 5.3.{1,2} of RFC 7231 _accept_re = re.compile( - r'''( # media-range capturing-parenthesis - [^\s;,]+ # type/subtype - (?:[ \t]*;[ \t]* # ";" - (?: # parameter non-capturing-parenthesis - [^\s;,q][^\s;,]* # token that doesn't start with "q" - | # or - q[^\s;,=][^\s;,]* # token that is more than just "q" - ) - )* # zero or more parameters - ) # end of media-range - (?:[ \t]*;[ \t]*q= # weight is a "q" parameter - (\d*(?:\.\d+)?) # qvalue capturing-parentheses - [^,]* # "extension" accept params: who cares? - )? # accept params are optional - ''', re.VERBOSE) -_token_chars = frozenset("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" - '^_`abcdefghijklmnopqrstuvwxyz|~') + r""" + ( # media-range capturing-parenthesis + [^\s;,]+ # type/subtype + (?:[ \t]*;[ \t]* # ";" + (?: # parameter non-capturing-parenthesis + [^\s;,q][^\s;,]* # token that doesn't start with "q" + | # or + q[^\s;,=][^\s;,]* # token that is more than just "q" + ) + )* # zero or more parameters + ) # end of media-range + (?:[ \t]*;[ \t]*q= # weight is a "q" parameter + (\d*(?:\.\d+)?) # qvalue capturing-parentheses + [^,]* # "extension" accept params: who cares? + )? # accept params are optional + """, + re.VERBOSE, +) +_token_chars = frozenset( + "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~" +) _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') -_unsafe_header_chars = set('()<>@,;:\"/[]?={} \t') -_option_header_piece_re = re.compile(r''' +_unsafe_header_chars = set('()<>@,;:"/[]?={} \t') +_option_header_piece_re = re.compile( + r""" ;\s*,?\s* # newlines were replaced with commas (?P<key> "[^"\\]*(?:\\.[^"\\]*)*" # quoted string @@ -89,100 +103,116 @@ _option_header_piece_re = re.compile(r''' )? )? \s* -''', flags=re.VERBOSE) -_option_header_start_mime_type = re.compile(r',\s*([^;,\s]+)([;,]\s*.+)?') - -_entity_headers = frozenset([ - 'allow', 'content-encoding', 'content-language', 'content-length', - 'content-location', 'content-md5', 'content-range', 'content-type', - 'expires', 'last-modified' -]) -_hop_by_hop_headers = frozenset([ - 'connection', 'keep-alive', 'proxy-authenticate', - 'proxy-authorization', 'te', 'trailer', 'transfer-encoding', - 'upgrade' -]) + """, + flags=re.VERBOSE, +) +_option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?") + +_entity_headers = frozenset( + [ + "allow", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-type", + "expires", + "last-modified", + ] +) +_hop_by_hop_headers = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", + ] +) HTTP_STATUS_CODES = { - 100: 'Continue', - 101: 'Switching Protocols', - 102: 'Processing', - 200: 'OK', - 201: 'Created', - 202: 'Accepted', - 203: 'Non Authoritative Information', - 204: 'No Content', - 205: 'Reset Content', - 206: 'Partial Content', - 207: 'Multi Status', - 226: 'IM Used', # see RFC 3229 - 300: 'Multiple Choices', - 301: 'Moved Permanently', - 302: 'Found', - 303: 'See Other', - 304: 'Not Modified', - 305: 'Use Proxy', - 307: 'Temporary Redirect', - 308: 'Permanent Redirect', - 400: 'Bad Request', - 401: 'Unauthorized', - 402: 'Payment Required', # unused - 403: 'Forbidden', - 404: 'Not Found', - 405: 'Method Not Allowed', - 406: 'Not Acceptable', - 407: 'Proxy Authentication Required', - 408: 'Request Timeout', - 409: 'Conflict', - 410: 'Gone', - 411: 'Length Required', - 412: 'Precondition Failed', - 413: 'Request Entity Too Large', - 414: 'Request URI Too Long', - 415: 'Unsupported Media Type', - 416: 'Requested Range Not Satisfiable', - 417: 'Expectation Failed', - 418: 'I\'m a teapot', # see RFC 2324 - 421: 'Misdirected Request', # see RFC 7540 - 422: 'Unprocessable Entity', - 423: 'Locked', - 424: 'Failed Dependency', - 426: 'Upgrade Required', - 428: 'Precondition Required', # see RFC 6585 - 429: 'Too Many Requests', - 431: 'Request Header Fields Too Large', - 449: 'Retry With', # proprietary MS extension - 451: 'Unavailable For Legal Reasons', - 500: 'Internal Server Error', - 501: 'Not Implemented', - 502: 'Bad Gateway', - 503: 'Service Unavailable', - 504: 'Gateway Timeout', - 505: 'HTTP Version Not Supported', - 507: 'Insufficient Storage', - 510: 'Not Extended' + 100: "Continue", + 101: "Switching Protocols", + 102: "Processing", + 200: "OK", + 201: "Created", + 202: "Accepted", + 203: "Non Authoritative Information", + 204: "No Content", + 205: "Reset Content", + 206: "Partial Content", + 207: "Multi Status", + 226: "IM Used", # see RFC 3229 + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 305: "Use Proxy", + 307: "Temporary Redirect", + 308: "Permanent Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", # unused + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Request Entity Too Large", + 414: "Request URI Too Long", + 415: "Unsupported Media Type", + 416: "Requested Range Not Satisfiable", + 417: "Expectation Failed", + 418: "I'm a teapot", # see RFC 2324 + 421: "Misdirected Request", # see RFC 7540 + 422: "Unprocessable Entity", + 423: "Locked", + 424: "Failed Dependency", + 426: "Upgrade Required", + 428: "Precondition Required", # see RFC 6585 + 429: "Too Many Requests", + 431: "Request Header Fields Too Large", + 449: "Retry With", # proprietary MS extension + 451: "Unavailable For Legal Reasons", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", + 507: "Insufficient Storage", + 510: "Not Extended", } def wsgi_to_bytes(data): - """coerce wsgi unicode represented bytes to real ones - - """ + """coerce wsgi unicode represented bytes to real ones""" if isinstance(data, bytes): return data - return data.encode('latin1') # XXX: utf8 fallback? + return data.encode("latin1") # XXX: utf8 fallback? def bytes_to_wsgi(data): - assert isinstance(data, bytes), 'data must be bytes' + assert isinstance(data, bytes), "data must be bytes" if isinstance(data, str): return data else: - return data.decode('latin1') + return data.decode("latin1") -def quote_header_value(value, extra_chars='', allow_token=True): +def quote_header_value(value, extra_chars="", allow_token=True): """Quote a header value if necessary. .. versionadded:: 0.5 @@ -199,7 +229,7 @@ def quote_header_value(value, extra_chars='', allow_token=True): token_chars = _token_chars | set(extra_chars) if set(value).issubset(token_chars): return value - return '"%s"' % value.replace('\\', '\\\\').replace('"', '\\"') + return '"%s"' % value.replace("\\", "\\\\").replace('"', '\\"') def unquote_header_value(value, is_filename=False): @@ -223,8 +253,8 @@ def unquote_header_value(value, is_filename=False): # replace sequence below on a UNC path has the effect of turning # the leading double slash into a single slash and then # _fix_ie_filename() doesn't work correctly. See #458. - if not is_filename or value[:2] != '\\\\': - return value.replace('\\\\', '\\').replace('\\"', '"') + if not is_filename or value[:2] != "\\\\": + return value.replace("\\\\", "\\").replace('\\"', '"') return value @@ -241,8 +271,8 @@ def dump_options_header(header, options): if value is None: segments.append(key) else: - segments.append('%s=%s' % (key, quote_header_value(value))) - return '; '.join(segments) + segments.append("%s=%s" % (key, quote_header_value(value))) + return "; ".join(segments) def dump_header(iterable, allow_token=True): @@ -266,14 +296,12 @@ def dump_header(iterable, allow_token=True): if value is None: items.append(key) else: - items.append('%s=%s' % ( - key, - quote_header_value(value, allow_token=allow_token) - )) + items.append( + "%s=%s" % (key, quote_header_value(value, allow_token=allow_token)) + ) else: - items = [quote_header_value(x, allow_token=allow_token) - for x in iterable] - return ', '.join(items) + items = [quote_header_value(x, allow_token=allow_token) for x in iterable] + return ", ".join(items) def parse_list_header(value): @@ -337,10 +365,10 @@ def parse_dict_header(value, cls=dict): # XXX: validate value = bytes_to_wsgi(value) for item in _parse_list_header(value): - if '=' not in item: + if "=" not in item: result[item] = None continue - name, value = item.split('=', 1) + name, value = item.split("=", 1) if value[:1] == value[-1:] == '"': value = unquote_header_value(value[1:-1]) result[name] = value @@ -369,7 +397,7 @@ def parse_options_header(value, multiple=False): if multiple=True """ if not value: - return '', {} + return "", {} result = [] @@ -400,9 +428,7 @@ def parse_options_header(value, multiple=False): continued_encoding = encoding option = unquote_header_value(option) if option_value is not None: - option_value = unquote_header_value( - option_value, - option == 'filename') + option_value = unquote_header_value(option_value, option == "filename") if encoding is not None: option_value = _unquote(option_value).decode(encoding) if count: @@ -412,13 +438,13 @@ def parse_options_header(value, multiple=False): options[option] = options.get(option, "") + option_value else: options[option] = option_value - rest = rest[optmatch.end():] + rest = rest[optmatch.end() :] result.append(options) if multiple is False: return tuple(result) value = rest - return tuple(result) if result else ('', {}) + return tuple(result) if result else ("", {}) def parse_accept_header(value, cls=None): @@ -525,26 +551,27 @@ def parse_authorization_header(value): auth_type = auth_type.lower() except ValueError: return - if auth_type == b'basic': + if auth_type == b"basic": try: - username, password = base64.b64decode(auth_info).split(b':', 1) + username, password = base64.b64decode(auth_info).split(b":", 1) except Exception: return return Authorization( - 'basic', { - 'username': to_unicode(username, _basic_auth_charset), - 'password': to_unicode(password, _basic_auth_charset) - } + "basic", + { + "username": to_unicode(username, _basic_auth_charset), + "password": to_unicode(password, _basic_auth_charset), + }, ) - elif auth_type == b'digest': + elif auth_type == b"digest": auth_map = parse_dict_header(auth_info) - for key in 'username', 'realm', 'nonce', 'uri', 'response': + for key in "username", "realm", "nonce", "uri", "response": if key not in auth_map: return - if 'qop' in auth_map: - if not auth_map.get('nc') or not auth_map.get('cnonce'): + if "qop" in auth_map: + if not auth_map.get("nc") or not auth_map.get("cnonce"): return - return Authorization('digest', auth_map) + return Authorization("digest", auth_map) def parse_www_authenticate_header(value, on_update=None): @@ -564,8 +591,7 @@ def parse_www_authenticate_header(value, on_update=None): auth_type = auth_type.lower() except (ValueError, AttributeError): return WWWAuthenticate(value.strip().lower(), on_update=on_update) - return WWWAuthenticate(auth_type, parse_dict_header(auth_info), - on_update) + return WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) def parse_if_range_header(value): @@ -591,19 +617,19 @@ def parse_range_header(value, make_inclusive=True): .. versionadded:: 0.7 """ - if not value or '=' not in value: + if not value or "=" not in value: return None ranges = [] last_end = 0 - units, rng = value.split('=', 1) + units, rng = value.split("=", 1) units = units.strip().lower() - for item in rng.split(','): + for item in rng.split(","): item = item.strip() - if '-' not in item: + if "-" not in item: return None - if item.startswith('-'): + if item.startswith("-"): if last_end < 0: return None try: @@ -612,8 +638,8 @@ def parse_range_header(value, make_inclusive=True): return None end = None last_end = -1 - elif '-' in item: - begin, end = item.split('-', 1) + elif "-" in item: + begin, end = item.split("-", 1) begin = begin.strip() end = end.strip() if not begin.isdigit(): @@ -650,26 +676,26 @@ def parse_content_range_header(value, on_update=None): if value is None: return None try: - units, rangedef = (value or '').strip().split(None, 1) + units, rangedef = (value or "").strip().split(None, 1) except ValueError: return None - if '/' not in rangedef: + if "/" not in rangedef: return None - rng, length = rangedef.split('/', 1) - if length == '*': + rng, length = rangedef.split("/", 1) + if length == "*": length = None elif length.isdigit(): length = int(length) else: return None - if rng == '*': + if rng == "*": return ContentRange(units, None, None, length, on_update=on_update) - elif '-' not in rng: + elif "-" not in rng: return None - start, stop = rng.split('-', 1) + start, stop = rng.split("-", 1) try: start = int(start) stop = int(stop) + 1 @@ -687,10 +713,10 @@ def quote_etag(etag, weak=False): :param weak: set to `True` to tag it "weak". """ if '"' in etag: - raise ValueError('invalid etag') + raise ValueError("invalid etag") etag = '"%s"' % etag if weak: - etag = 'W/' + etag + etag = "W/" + etag return etag @@ -709,7 +735,7 @@ def unquote_etag(etag): return None, None etag = etag.strip() weak = False - if etag.startswith(('W/', 'w/')): + if etag.startswith(("W/", "w/")): weak = True etag = etag[2:] if etag[:1] == etag[-1:] == '"': @@ -734,7 +760,7 @@ def parse_etags(value): if match is None: break is_weak, quoted, raw = match.groups() - if raw == '*': + if raw == "*": return ETags(star_tag=True) elif quoted: raw = quoted @@ -778,8 +804,7 @@ def parse_date(value): year += 2000 elif year >= 69 and year <= 99: year += 1900 - return datetime(*((year,) + t[1:7])) - \ - timedelta(seconds=t[-1] or 0) + return datetime(*((year,) + t[1:7])) - timedelta(seconds=t[-1] or 0) except (ValueError, OverflowError): return None @@ -792,12 +817,29 @@ def _dump_date(d, delim): d = d.utctimetuple() elif isinstance(d, (integer_types, float)): d = gmtime(d) - return '%s, %02d%s%s%s%s %02d:%02d:%02d GMT' % ( - ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')[d.tm_wday], - d.tm_mday, delim, - ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', - 'Oct', 'Nov', 'Dec')[d.tm_mon - 1], - delim, str(d.tm_year), d.tm_hour, d.tm_min, d.tm_sec + return "%s, %02d%s%s%s%s %02d:%02d:%02d GMT" % ( + ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[d.tm_wday], + d.tm_mday, + delim, + ( + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + )[d.tm_mon - 1], + delim, + str(d.tm_year), + d.tm_hour, + d.tm_min, + d.tm_sec, ) @@ -813,7 +855,7 @@ def cookie_date(expires=None): :param expires: If provided that date is used, otherwise the current. """ - return _dump_date(expires, '-') + return _dump_date(expires, "-") def http_date(timestamp=None): @@ -827,7 +869,7 @@ def http_date(timestamp=None): :param timestamp: If provided that date is used, otherwise the current. """ - return _dump_date(timestamp, ' ') + return _dump_date(timestamp, " ") def parse_age(value=None): @@ -868,13 +910,14 @@ def dump_age(age=None): age = int(age) if age < 0: - raise ValueError('age cannot be negative') + raise ValueError("age cannot be negative") return str(age) -def is_resource_modified(environ, etag=None, data=None, last_modified=None, - ignore_if_range=True): +def is_resource_modified( + environ, etag=None, data=None, last_modified=None, ignore_if_range=True +): """Convenience method for conditional requests. :param environ: the WSGI environment of the request to be checked. @@ -889,8 +932,8 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, if etag is None and data is not None: etag = generate_etag(data) elif data is not None: - raise TypeError('both data and etag given') - if environ['REQUEST_METHOD'] not in ('GET', 'HEAD'): + raise TypeError("both data and etag given") + if environ["REQUEST_METHOD"] not in ("GET", "HEAD"): return False unmodified = False @@ -903,16 +946,16 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, last_modified = last_modified.replace(microsecond=0) if_range = None - if not ignore_if_range and 'HTTP_RANGE' in environ: + if not ignore_if_range and "HTTP_RANGE" in environ: # https://tools.ietf.org/html/rfc7233#section-3.2 # A server MUST ignore an If-Range header field received in a request # that does not contain a Range header field. - if_range = parse_if_range_header(environ.get('HTTP_IF_RANGE')) + if_range = parse_if_range_header(environ.get("HTTP_IF_RANGE")) if if_range is not None and if_range.date is not None: modified_since = if_range.date else: - modified_since = parse_date(environ.get('HTTP_IF_MODIFIED_SINCE')) + modified_since = parse_date(environ.get("HTTP_IF_MODIFIED_SINCE")) if modified_since and last_modified and last_modified <= modified_since: unmodified = True @@ -922,7 +965,7 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, if if_range is not None and if_range.etag is not None: unmodified = parse_etags(if_range.etag).contains(etag) else: - if_none_match = parse_etags(environ.get('HTTP_IF_NONE_MATCH')) + if_none_match = parse_etags(environ.get("HTTP_IF_NONE_MATCH")) if if_none_match: # https://tools.ietf.org/html/rfc7232#section-3.2 # "A recipient MUST use the weak comparison function when comparing @@ -932,14 +975,14 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None, # https://tools.ietf.org/html/rfc7232#section-3.1 # "Origin server MUST use the strong comparison function when # comparing entity-tags for If-Match" - if_match = parse_etags(environ.get('HTTP_IF_MATCH')) + if_match = parse_etags(environ.get("HTTP_IF_MATCH")) if if_match: unmodified = not if_match.is_strong(etag) return not unmodified -def remove_entity_headers(headers, allowed=('expires', 'content-location')): +def remove_entity_headers(headers, allowed=("expires", "content-location")): """Remove all entity headers from a list or :class:`Headers` object. This operation works in-place. `Expires` and `Content-Location` headers are by default not removed. The reason for this is :rfc:`2616` section @@ -953,8 +996,11 @@ def remove_entity_headers(headers, allowed=('expires', 'content-location')): they are entity headers. """ allowed = set(x.lower() for x in allowed) - headers[:] = [(key, value) for key, value in headers if - not is_entity_header(key) or key.lower() in allowed] + headers[:] = [ + (key, value) + for key, value in headers + if not is_entity_header(key) or key.lower() in allowed + ] def remove_hop_by_hop_headers(headers): @@ -965,8 +1011,9 @@ def remove_hop_by_hop_headers(headers): :param headers: a list or :class:`Headers` object. """ - headers[:] = [(key, value) for key, value in headers if - not is_hop_by_hop_header(key)] + headers[:] = [ + (key, value) for key, value in headers if not is_hop_by_hop_header(key) + ] def is_entity_header(header): @@ -991,7 +1038,7 @@ def is_hop_by_hop_header(header): return header.lower() in _hop_by_hop_headers -def parse_cookie(header, charset='utf-8', errors='replace', cls=None): +def parse_cookie(header, charset="utf-8", errors="replace", cls=None): """Parse a cookie. Either from a string or WSGI environ. Per default encoding errors are ignored. If you want a different behavior @@ -1011,16 +1058,16 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None): used. """ if isinstance(header, dict): - header = header.get('HTTP_COOKIE', '') + header = header.get("HTTP_COOKIE", "") elif header is None: - header = '' + header = "" # If the value is an unicode string it's mangled through latin1. This # is done because on PEP 3333 on Python 3 all headers are assumed latin1 # which however is incorrect for cookies, which are sent in page encoding. # As a result we if isinstance(header, text_type): - header = header.encode('latin1', 'replace') + header = header.encode("latin1", "replace") if cls is None: cls = TypeConversionDict @@ -1036,10 +1083,20 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None): return cls(_parse_pairs()) -def dump_cookie(key, value='', max_age=None, expires=None, path='/', - domain=None, secure=False, httponly=False, - charset='utf-8', sync_expires=True, max_size=4093, - samesite=None): +def dump_cookie( + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=False, + httponly=False, + charset="utf-8", + sync_expires=True, + max_size=4093, + samesite=None, +): """Creates a new Set-Cookie header without the ``Set-Cookie`` prefix The parameters are the same as in the cookie Morsel object in the Python standard library but it accepts unicode data, too. @@ -1098,21 +1155,23 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', expires = to_bytes(cookie_date(time() + max_age)) samesite = samesite.title() if samesite else None - if samesite not in ('Strict', 'Lax', None): + if samesite not in ("Strict", "Lax", None): raise ValueError("invalid SameSite value; must be 'Strict', 'Lax' or None") - buf = [key + b'=' + _cookie_quote(value)] + buf = [key + b"=" + _cookie_quote(value)] # XXX: In theory all of these parameters that are not marked with `None` # should be quoted. Because stdlib did not quote it before I did not # want to introduce quoting there now. - for k, v, q in ((b'Domain', domain, True), - (b'Expires', expires, False,), - (b'Max-Age', max_age, False), - (b'Secure', secure, None), - (b'HttpOnly', httponly, None), - (b'Path', path, False), - (b'SameSite', samesite, False)): + for k, v, q in ( + (b"Domain", domain, True), + (b"Expires", expires, False), + (b"Max-Age", max_age, False), + (b"Secure", secure, None), + (b"HttpOnly", httponly, None), + (b"Path", path, False), + (b"SameSite", samesite, False), + ): if q is None: if v: buf.append(k) @@ -1126,15 +1185,15 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', v = to_bytes(text_type(v), charset) if q: v = _cookie_quote(v) - tmp += b'=' + v + tmp += b"=" + v buf.append(bytes(tmp)) # The return value will be an incorrectly encoded latin1 header on # Python 3 for consistency with the headers object and a bytestring # on Python 2 because that's how the API makes more sense. - rv = b'; '.join(buf) + rv = b"; ".join(buf) if not PY2: - rv = rv.decode('latin1') + rv = rv.decode("latin1") # Warn if the final value of the cookie is less than the limit. If the # cookie is too large, then it may be silently ignored, which can be quite @@ -1145,16 +1204,16 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/', value_size = len(value) warnings.warn( 'The "{key}" cookie is too large: the value was {value_size} bytes' - ' but the header required {extra_size} extra bytes. The final size' - ' was {cookie_size} bytes but the limit is {max_size} bytes.' - ' Browsers may silently ignore cookies larger than this.'.format( + " but the header required {extra_size} extra bytes. The final size" + " was {cookie_size} bytes but the limit is {max_size} bytes." + " Browsers may silently ignore cookies larger than this.".format( key=key, value_size=value_size, extra_size=cookie_size - value_size, cookie_size=cookie_size, - max_size=max_size + max_size=max_size, ), - stacklevel=2 + stacklevel=2, ) return rv @@ -1177,19 +1236,23 @@ def is_byte_range_valid(start, stop, length): # circular dependency fun -from werkzeug.datastructures import Accept, HeaderSet, ETags, Authorization, \ - WWWAuthenticate, TypeConversionDict, IfRange, Range, ContentRange, \ - RequestCacheControl -from werkzeug.urls import iri_to_uri - +from .datastructures import Accept +from .datastructures import Authorization +from .datastructures import ContentRange +from .datastructures import ETags +from .datastructures import HeaderSet +from .datastructures import IfRange +from .datastructures import Range +from .datastructures import RequestCacheControl +from .datastructures import TypeConversionDict +from .datastructures import WWWAuthenticate +from .urls import iri_to_uri # DEPRECATED -from werkzeug.datastructures import ( - MIMEAccept as _MIMEAccept, - CharsetAccept as _CharsetAccept, - LanguageAccept as _LanguageAccept, - Headers as _Headers, -) +from .datastructures import CharsetAccept as _CharsetAccept +from .datastructures import Headers as _Headers +from .datastructures import LanguageAccept as _LanguageAccept +from .datastructures import MIMEAccept as _MIMEAccept class MIMEAccept(_MIMEAccept): diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index af98c305..9a6088cc 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -10,8 +10,10 @@ """ import copy from functools import update_wrapper -from werkzeug.wsgi import ClosingIterator -from werkzeug._compat import PY2, implements_bool + +from ._compat import implements_bool +from ._compat import PY2 +from .wsgi import ClosingIterator # since each thread has its own greenlet we can just use those as identifiers # for the context. If greenlets are not available we fall back to the @@ -49,11 +51,11 @@ def release_local(local): class Local(object): - __slots__ = ('__storage__', '__ident_func__') + __slots__ = ("__storage__", "__ident_func__") def __init__(self): - object.__setattr__(self, '__storage__', {}) - object.__setattr__(self, '__ident_func__', get_ident) + object.__setattr__(self, "__storage__", {}) + object.__setattr__(self, "__ident_func__", get_ident) def __iter__(self): return iter(self.__storage__.items()) @@ -87,7 +89,6 @@ class Local(object): class LocalStack(object): - """This class works similar to a :class:`Local` but keeps a stack of objects instead. This is best explained with an example:: @@ -124,7 +125,8 @@ class LocalStack(object): return self._local.__ident_func__ def _set__ident_func__(self, value): - object.__setattr__(self._local, '__ident_func__', value) + object.__setattr__(self._local, "__ident_func__", value) + __ident_func__ = property(_get__ident_func__, _set__ident_func__) del _get__ident_func__, _set__ident_func__ @@ -132,13 +134,14 @@ class LocalStack(object): def _lookup(): rv = self.top if rv is None: - raise RuntimeError('object unbound') + raise RuntimeError("object unbound") return rv + return LocalProxy(_lookup) def push(self, obj): """Pushes a new item to the stack""" - rv = getattr(self._local, 'stack', None) + rv = getattr(self._local, "stack", None) if rv is None: self._local.stack = rv = [] rv.append(obj) @@ -148,7 +151,7 @@ class LocalStack(object): """Removes the topmost item from the stack, will return the old value or `None` if the stack was already empty. """ - stack = getattr(self._local, 'stack', None) + stack = getattr(self._local, "stack", None) if stack is None: return None elif len(stack) == 1: @@ -169,7 +172,6 @@ class LocalStack(object): class LocalManager(object): - """Local objects cannot manage themselves. For that you need a local manager. You can pass a local manager multiple locals or add them later by appending them to `manager.locals`. Every time the manager cleans up, @@ -196,7 +198,7 @@ class LocalManager(object): if ident_func is not None: self.ident_func = ident_func for local in self.locals: - object.__setattr__(local, '__ident_func__', ident_func) + object.__setattr__(local, "__ident_func__", ident_func) else: self.ident_func = get_ident @@ -224,8 +226,10 @@ class LocalManager(object): """Wrap a WSGI application so that cleaning up happens after request end. """ + def application(environ, start_response): return ClosingIterator(app(environ, start_response), self.cleanup) + return application def middleware(self, func): @@ -244,15 +248,11 @@ class LocalManager(object): return update_wrapper(self.make_middleware(func), func) def __repr__(self): - return '<%s storages: %d>' % ( - self.__class__.__name__, - len(self.locals) - ) + return "<%s storages: %d>" % (self.__class__.__name__, len(self.locals)) @implements_bool class LocalProxy(object): - """Acts as a proxy for a werkzeug local. Forwards all operations to a proxied object. The only operations not supported for forwarding are right handed operands and any kind of assignment. @@ -287,40 +287,41 @@ class LocalProxy(object): .. versionchanged:: 0.6.1 The class can be instantiated with a callable as well now. """ - __slots__ = ('__local', '__dict__', '__name__', '__wrapped__') + + __slots__ = ("__local", "__dict__", "__name__", "__wrapped__") def __init__(self, local, name=None): - object.__setattr__(self, '_LocalProxy__local', local) - object.__setattr__(self, '__name__', name) - if callable(local) and not hasattr(local, '__release_local__'): + object.__setattr__(self, "_LocalProxy__local", local) + object.__setattr__(self, "__name__", name) + if callable(local) and not hasattr(local, "__release_local__"): # "local" is a callable that is not an instance of Local or # LocalManager: mark it as a wrapped function. - object.__setattr__(self, '__wrapped__', local) + object.__setattr__(self, "__wrapped__", local) def _get_current_object(self): """Return the current object. This is useful if you want the real object behind the proxy at a time for performance reasons or because you want to pass the object into a different context. """ - if not hasattr(self.__local, '__release_local__'): + if not hasattr(self.__local, "__release_local__"): return self.__local() try: return getattr(self.__local, self.__name__) except AttributeError: - raise RuntimeError('no object bound to %s' % self.__name__) + raise RuntimeError("no object bound to %s" % self.__name__) @property def __dict__(self): try: return self._get_current_object().__dict__ except RuntimeError: - raise AttributeError('__dict__') + raise AttributeError("__dict__") def __repr__(self): try: obj = self._get_current_object() except RuntimeError: - return '<%s unbound>' % self.__class__.__name__ + return "<%s unbound>" % self.__class__.__name__ return repr(obj) def __bool__(self): @@ -342,7 +343,7 @@ class LocalProxy(object): return [] def __getattr__(self, name): - if name == '__members__': + if name == "__members__": return dir(self._get_current_object()) return getattr(self._get_current_object(), name) diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 0890cf00..bfdc0712 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,7 +7,6 @@ Basic HTTP Proxy :copyright: 2007 Pallets :license: BSD-3-Clause """ - import socket from ..datastructures import EnvironHeaders @@ -117,7 +116,7 @@ class ProxyMiddleware(object): if opts["remove_prefix"]: remote_path = "%s/%s" % ( target.path.rstrip("/"), - remote_path[len(prefix):].lstrip("/"), + remote_path[len(prefix) :].lstrip("/"), ) content_length = environ.get("CONTENT_LENGTH") @@ -179,7 +178,7 @@ class ProxyMiddleware(object): resp = con.getresponse() except socket.error: - from werkzeug.exceptions import BadGateway + from ..exceptions import BadGateway return BadGateway()(environ, start_response) diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 0c71f9ca..15372819 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -164,7 +164,7 @@ class GuardedIterator(object): content_length = headers.get("content-length", type=int) if status_code == 304: - for key, value in headers: + for key, _value in headers: key = key.lower() if key not in ("expires", "content-location") and is_entity_header( key diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 64041f61..8e2edc20 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -12,6 +12,7 @@ that may be slowing down your application. :license: BSD-3-Clause """ from __future__ import print_function + import os.path import sys import time diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index 6fcd2ef6..dc1dacc8 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -141,7 +141,7 @@ class ProxyFix(object): "'get_remote_addr' is deprecated as of version 0.15 and" " will be removed in version 1.0. It is now handled" " internally for each header.", - DeprecationWarning + DeprecationWarning, ) return self._get_trusted_comma(self.x_for, ",".join(forwarded_for)) diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 5ea3f87e..a902281d 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -8,7 +8,6 @@ Serve Shared Static Files :copyright: 2007 Pallets :license: BSD-3-Clause """ - import mimetypes import os import posixpath @@ -219,7 +218,7 @@ class SharedDataMiddleware(object): search_path += "/" if path.startswith(search_path): - real_filename, file_loader = loader(path[len(search_path):]) + real_filename, file_loader = loader(path[len(search_path) :]) if file_loader is not None: break diff --git a/src/werkzeug/posixemulation.py b/src/werkzeug/posixemulation.py index dbf909de..696b4562 100644 --- a/src/werkzeug/posixemulation.py +++ b/src/werkzeug/posixemulation.py @@ -17,21 +17,18 @@ r""" :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys -import os import errno -import time +import os import random +import sys +import time from ._compat import to_unicode from .filesystem import get_filesystem_encoding - can_rename_open_file = False -if os.name == 'nt': # pragma: no cover - _rename = lambda src, dst: False - _rename_atomic = lambda src, dst: False +if os.name == "nt": try: import ctypes @@ -47,8 +44,9 @@ if os.name == 'nt': # pragma: no cover retry = 0 rv = False while not rv and retry < 100: - rv = _MoveFileEx(src, dst, _MOVEFILE_REPLACE_EXISTING - | _MOVEFILE_WRITE_THROUGH) + rv = _MoveFileEx( + src, dst, _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH + ) if not rv: time.sleep(0.001) retry += 1 @@ -62,16 +60,21 @@ if os.name == 'nt': # pragma: no cover can_rename_open_file = True def _rename_atomic(src, dst): - ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, 'Werkzeug rename') + ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, "Werkzeug rename") if ta == -1: return False try: retry = 0 rv = False while not rv and retry < 100: - rv = _MoveFileTransacted(src, dst, None, None, - _MOVEFILE_REPLACE_EXISTING - | _MOVEFILE_WRITE_THROUGH, ta) + rv = _MoveFileTransacted( + src, + dst, + None, + None, + _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH, + ta, + ) if rv: rv = _CommitTransaction(ta) break @@ -81,8 +84,14 @@ if os.name == 'nt': # pragma: no cover return rv finally: _CloseHandle(ta) + except Exception: - pass + + def _rename(src, dst): + return False + + def _rename_atomic(src, dst): + return False def rename(src, dst): # Try atomic or pseudo-atomic rename @@ -101,6 +110,8 @@ if os.name == 'nt': # pragma: no cover os.unlink(old) except Exception: pass + + else: rename = os.rename can_rename_open_file = True diff --git a/src/werkzeug/routing.py b/src/werkzeug/routing.py index 3e349814..875c0154 100644 --- a/src/werkzeug/routing.py +++ b/src/werkzeug/routing.py @@ -96,30 +96,44 @@ :license: BSD-3-Clause """ import difflib -import re -import uuid -import posixpath import dis +import posixpath +import re import sys import types - +import uuid from functools import partial from pprint import pformat from threading import Lock -from werkzeug.urls import url_encode, url_quote, url_join, _fast_url_quote -from werkzeug.utils import redirect, format_string -from werkzeug.exceptions import HTTPException, NotFound, MethodNotAllowed, \ - BadHost -from werkzeug._internal import _get_environ, _encode_idna -from werkzeug._compat import itervalues, iteritems, to_unicode, to_bytes, \ - text_type, string_types, native_string_result, \ - implements_to_string, wsgi_decoding_dance -from werkzeug.datastructures import ImmutableDict, MultiDict -from werkzeug.utils import cached_property -from werkzeug.wsgi import get_host - -_rule_re = re.compile(r''' +from ._compat import implements_to_string +from ._compat import iteritems +from ._compat import itervalues +from ._compat import native_string_result +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import wsgi_decoding_dance +from ._internal import _encode_idna +from ._internal import _get_environ +from .datastructures import ImmutableDict +from .datastructures import MultiDict +from .exceptions import BadHost +from .exceptions import HTTPException +from .exceptions import MethodNotAllowed +from .exceptions import NotFound +from .urls import _fast_url_quote +from .urls import url_encode +from .urls import url_join +from .urls import url_quote +from .utils import cached_property +from .utils import format_string +from .utils import redirect +from .wsgi import get_host + +_rule_re = re.compile( + r""" (?P<static>[^<]*) # static rule data < (?: @@ -129,9 +143,12 @@ _rule_re = re.compile(r''' )? (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*) # variable name > -''', re.VERBOSE) -_simple_rule_re = re.compile(r'<([^>]+)>') -_converter_args_re = re.compile(r''' + """, + re.VERBOSE, +) +_simple_rule_re = re.compile(r"<([^>]+)>") +_converter_args_re = re.compile( + r""" ((?P<name>\w+)\s*=\s*)? (?P<value> True|False| @@ -141,14 +158,12 @@ _converter_args_re = re.compile(r''' [\w\d_.]+| [urUR]?(?P<stringval>"[^"]*?"|'[^']*') )\s*, -''', re.VERBOSE | re.UNICODE) + """, + re.VERBOSE | re.UNICODE, +) -_PYTHON_CONSTANTS = { - 'None': None, - 'True': True, - 'False': False -} +_PYTHON_CONSTANTS = {"None": None, "True": True, "False": False} def _pythonize(value): @@ -159,25 +174,25 @@ def _pythonize(value): return convert(value) except ValueError: pass - if value[:1] == value[-1:] and value[0] in '"\'': + if value[:1] == value[-1:] and value[0] in "\"'": value = value[1:-1] return text_type(value) def parse_converter_args(argstr): - argstr += ',' + argstr += "," args = [] kwargs = {} for item in _converter_args_re.finditer(argstr): - value = item.group('stringval') + value = item.group("stringval") if value is None: - value = item.group('value') + value = item.group("value") value = _pythonize(value) - if not item.group('name'): + if not item.group("name"): args.append(value) else: - name = item.group('name') + name = item.group("name") kwargs[name] = value return tuple(args), kwargs @@ -199,24 +214,23 @@ def parse_rule(rule): if m is None: break data = m.groupdict() - if data['static']: - yield None, None, data['static'] - variable = data['variable'] - converter = data['converter'] or 'default' + if data["static"]: + yield None, None, data["static"] + variable = data["variable"] + converter = data["converter"] or "default" if variable in used_names: - raise ValueError('variable name %r used twice.' % variable) + raise ValueError("variable name %r used twice." % variable) used_names.add(variable) - yield converter, data['args'] or None, variable + yield converter, data["args"] or None, variable pos = m.end() if pos < end: remaining = rule[pos:] - if '>' in remaining or '<' in remaining: - raise ValueError('malformed url rule: %r' % rule) + if ">" in remaining or "<" in remaining: + raise ValueError("malformed url rule: %r" % rule) yield None, None, remaining class RoutingException(Exception): - """Special exceptions that require the application to redirect, notifying about missing urls, etc. @@ -225,12 +239,12 @@ class RoutingException(Exception): class RequestRedirect(HTTPException, RoutingException): - """Raise if the map requests a redirect. This is for example the case if `strict_slashes` are activated and an url that requires a trailing slash. The attribute `new_url` contains the absolute destination url. """ + code = 308 def __init__(self, new_url): @@ -242,12 +256,10 @@ class RequestRedirect(HTTPException, RoutingException): class RequestSlash(RoutingException): - """Internal exception.""" -class RequestAliasRedirect(RoutingException): - +class RequestAliasRedirect(RoutingException): # noqa: B903 """This rule is an alias and wants to redirect to the canonical URL.""" def __init__(self, matched_values): @@ -256,7 +268,6 @@ class RequestAliasRedirect(RoutingException): @implements_to_string class BuildError(RoutingException, LookupError): - """Raised if the build system cannot find a URL for an endpoint with the values provided. """ @@ -274,55 +285,54 @@ class BuildError(RoutingException, LookupError): def closest_rule(self, adapter): def _score_rule(rule): - return sum([ - 0.98 * difflib.SequenceMatcher( - None, rule.endpoint, self.endpoint - ).ratio(), - 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), - 0.01 * bool(rule.methods and self.method in rule.methods) - ]) + return sum( + [ + 0.98 + * difflib.SequenceMatcher( + None, rule.endpoint, self.endpoint + ).ratio(), + 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), + 0.01 * bool(rule.methods and self.method in rule.methods), + ] + ) if adapter and adapter.map._rules: return max(adapter.map._rules, key=_score_rule) def __str__(self): message = [] - message.append('Could not build url for endpoint %r' % self.endpoint) + message.append("Could not build url for endpoint %r" % self.endpoint) if self.method: - message.append(' (%r)' % self.method) + message.append(" (%r)" % self.method) if self.values: - message.append(' with values %r' % sorted(self.values.keys())) - message.append('.') + message.append(" with values %r" % sorted(self.values.keys())) + message.append(".") if self.suggested: if self.endpoint == self.suggested.endpoint: if self.method and self.method not in self.suggested.methods: - message.append(' Did you mean to use methods %r?' % sorted( - self.suggested.methods - )) + message.append( + " Did you mean to use methods %r?" + % sorted(self.suggested.methods) + ) missing_values = self.suggested.arguments.union( set(self.suggested.defaults or ()) ) - set(self.values.keys()) if missing_values: message.append( - ' Did you forget to specify values %r?' % - sorted(missing_values) + " Did you forget to specify values %r?" % sorted(missing_values) ) else: - message.append( - ' Did you mean %r instead?' % self.suggested.endpoint - ) - return u''.join(message) + message.append(" Did you mean %r instead?" % self.suggested.endpoint) + return u"".join(message) class ValidationError(ValueError): - """Validation error. If a rule converter raises this exception the rule does not match the current URL and the next URL is tried. """ class RuleFactory(object): - """As soon as you have more complex URL setups it's a good idea to use rule factories to avoid repetitive tasks. Some of them are builtin, others can be added by subclassing `RuleFactory` and overriding `get_rules`. @@ -335,7 +345,6 @@ class RuleFactory(object): class Subdomain(RuleFactory): - """All URLs provided by this factory have the subdomain set to a specific domain. For example if you want to use the subdomain for the current language this can be a good setup:: @@ -367,7 +376,6 @@ class Subdomain(RuleFactory): class Submount(RuleFactory): - """Like `Subdomain` but prefixes the URL rule with a given string:: url_map = Map([ @@ -382,7 +390,7 @@ class Submount(RuleFactory): """ def __init__(self, path, rules): - self.path = path.rstrip('/') + self.path = path.rstrip("/") self.rules = rules def get_rules(self, map): @@ -394,7 +402,6 @@ class Submount(RuleFactory): class EndpointPrefix(RuleFactory): - """Prefixes all endpoints (which must be strings for this factory) with another string. This can be useful for sub applications:: @@ -420,7 +427,6 @@ class EndpointPrefix(RuleFactory): class RuleTemplate(object): - """Returns copies of the rules wrapped and expands string templates in the endpoint, rule, defaults or subdomain sections. @@ -447,7 +453,6 @@ class RuleTemplate(object): class RuleTemplateFactory(RuleFactory): - """A factory that fills in template variables into rules. Used by `RuleTemplate` internally. @@ -480,13 +485,12 @@ class RuleTemplateFactory(RuleFactory): rule.methods, rule.build_only, new_endpoint, - rule.strict_slashes + rule.strict_slashes, ) @implements_to_string class Rule(RuleFactory): - """A Rule represents one URL pattern. There are some options for `Rule` that change the way it behaves and are passed to the `Rule` constructor. Note that besides the rule-string all arguments *must* be keyword arguments @@ -600,13 +604,23 @@ class Rule(RuleFactory): The `alias` and `host` parameters were added. """ - def __init__(self, string, defaults=None, subdomain=None, methods=None, - build_only=False, endpoint=None, strict_slashes=None, - redirect_to=None, alias=False, host=None): - if not string.startswith('/'): - raise ValueError('urls must start with a leading slash') + def __init__( + self, + string, + defaults=None, + subdomain=None, + methods=None, + build_only=False, + endpoint=None, + strict_slashes=None, + redirect_to=None, + alias=False, + host=None, + ): + if not string.startswith("/"): + raise ValueError("urls must start with a leading slash") self.rule = string - self.is_leaf = not string.endswith('/') + self.is_leaf = not string.endswith("/") self.map = None self.strict_slashes = strict_slashes @@ -619,10 +633,10 @@ class Rule(RuleFactory): self.methods = None else: if isinstance(methods, str): - raise TypeError('param `methods` should be `Iterable[str]`, not `str`') + raise TypeError("param `methods` should be `Iterable[str]`, not `str`") self.methods = set([x.upper() for x in methods]) - if 'HEAD' not in self.methods and 'GET' in self.methods: - self.methods.add('HEAD') + if "HEAD" not in self.methods and "GET" in self.methods: + self.methods.add("HEAD") self.endpoint = endpoint self.redirect_to = redirect_to @@ -657,11 +671,17 @@ class Rule(RuleFactory): defaults = None if self.defaults: defaults = dict(self.defaults) - return dict(defaults=defaults, subdomain=self.subdomain, - methods=self.methods, build_only=self.build_only, - endpoint=self.endpoint, strict_slashes=self.strict_slashes, - redirect_to=self.redirect_to, alias=self.alias, - host=self.host) + return dict( + defaults=defaults, + subdomain=self.subdomain, + methods=self.methods, + build_only=self.build_only, + endpoint=self.endpoint, + strict_slashes=self.strict_slashes, + redirect_to=self.redirect_to, + alias=self.alias, + host=self.host, + ) def get_rules(self, map): yield self @@ -681,8 +701,7 @@ class Rule(RuleFactory): :internal: """ if self.map is not None and not rebind: - raise RuntimeError('url rule %r already bound to map %r' % - (self, self.map)) + raise RuntimeError("url rule %r already bound to map %r" % (self, self.map)) self.map = map if self.strict_slashes is None: self.strict_slashes = map.strict_slashes @@ -696,17 +715,17 @@ class Rule(RuleFactory): .. versionadded:: 0.9 """ if converter_name not in self.map.converters: - raise LookupError('the converter %r does not exist' % converter_name) + raise LookupError("the converter %r does not exist" % converter_name) return self.map.converters[converter_name](self.map, *args, **kwargs) def compile(self): """Compiles the regular expression and stores it.""" - assert self.map is not None, 'rule not bound' + assert self.map is not None, "rule not bound" if self.map.host_matching: - domain_rule = self.host or '' + domain_rule = self.host or "" else: - domain_rule = self.subdomain or '' + domain_rule = self.subdomain or "" self._trace = [] self._converters = {} @@ -720,7 +739,7 @@ class Rule(RuleFactory): if converter is None: regex_parts.append(re.escape(variable)) self._trace.append((False, variable)) - for part in variable.split('/'): + for part in variable.split("/"): if part: self._static_weights.append((index, -len(part))) else: @@ -729,9 +748,8 @@ class Rule(RuleFactory): else: c_args = () c_kwargs = {} - convobj = self.get_converter( - variable, converter, c_args, c_kwargs) - regex_parts.append('(?P<%s>%s)' % (variable, convobj.regex)) + convobj = self.get_converter(variable, converter, c_args, c_kwargs) + regex_parts.append("(?P<%s>%s)" % (variable, convobj.regex)) self._converters[variable] = convobj self._trace.append((True, variable)) self._argument_weights.append(convobj.weight) @@ -739,21 +757,22 @@ class Rule(RuleFactory): index = index + 1 _build_regex(domain_rule) - regex_parts.append('\\|') - self._trace.append((False, '|')) - _build_regex(self.rule if self.is_leaf else self.rule.rstrip('/')) + regex_parts.append("\\|") + self._trace.append((False, "|")) + _build_regex(self.rule if self.is_leaf else self.rule.rstrip("/")) if not self.is_leaf: - self._trace.append((False, '/')) + self._trace.append((False, "/")) self._build = self._compile_builder(False) self._build_unknown = self._compile_builder(True) if self.build_only: return - regex = r'^%s%s$' % ( - u''.join(regex_parts), + regex = r"^%s%s$" % ( + u"".join(regex_parts), (not self.is_leaf or not self.strict_slashes) - and '(?<!/)(?P<__suffix__>/?)' or '' + and "(?<!/)(?P<__suffix__>/?)" + or "", ) self._regex = re.compile(regex, re.UNICODE) @@ -776,15 +795,19 @@ class Rule(RuleFactory): # slash and strict slashes enabled. raise an exception that # tells the map to redirect to the same url but with a # trailing slash - if self.strict_slashes and not self.is_leaf and \ - not groups.pop('__suffix__') and \ - (method is None or self.methods is None - or method in self.methods): + if ( + self.strict_slashes + and not self.is_leaf + and not groups.pop("__suffix__") + and ( + method is None or self.methods is None or method in self.methods + ) + ): raise RequestSlash() # if we are not in strict slashes mode we have to remove # a __suffix__ elif not self.strict_slashes: - del groups['__suffix__'] + del groups["__suffix__"] result = {} for name, value in iteritems(groups): @@ -802,7 +825,7 @@ class Rule(RuleFactory): return result class BuilderCompiler: - JOIN_EMPTY = ''.join + JOIN_EMPTY = "".join if sys.version_info >= (3, 6): OPARG_SIZE = 256 OPARG_VARI = False @@ -876,14 +899,14 @@ class Rule(RuleFactory): if op is not None: new.append((op, elem)) continue - if elem == '': + if elem == "": continue if not new or new[-1][0] is not None: new.append((op, elem)) continue new[-1] = (None, new[-1][1] + elem) if not new: - new.append((None, '')) + new.append((None, "")) return new def build_op(self, op, arg=None): @@ -891,20 +914,17 @@ class Rule(RuleFactory): if isinstance(op, str): op = dis.opmap[op] if arg is None and op >= dis.HAVE_ARGUMENT: - raise ValueError( - "Operation requires an argument: %s" % dis.opname[op]) + raise ValueError("Operation requires an argument: %s" % dis.opname[op]) if arg is not None and op < dis.HAVE_ARGUMENT: - raise ValueError( - "Operation takes no argument: %s" % dis.opname[op]) + raise ValueError("Operation takes no argument: %s" % dis.opname[op]) if arg is None: arg = 0 # Python 3.6 changed the argument to an 8-bit integer, so this # could be a practical consideration if arg >= self.OPARG_SIZE: - return ( - self.build_op('EXTENDED_ARG', arg // self.OPARG_SIZE) - + self.build_op(op, arg % self.OPARG_SIZE) - ) + return self.build_op( + "EXTENDED_ARG", arg // self.OPARG_SIZE + ) + self.build_op(op, arg % self.OPARG_SIZE) if not self.OPARG_VARI: return bytearray((op, arg)) elif op >= dis.HAVE_ARGUMENT: @@ -918,58 +938,57 @@ class Rule(RuleFactory): already be immediately below the string elements on the stack. """ - if 'BUILD_STRING' in dis.opmap: - return self.build_op('BUILD_STRING', n) + if "BUILD_STRING" in dis.opmap: + return self.build_op("BUILD_STRING", n) else: - return ( - self.build_op('BUILD_TUPLE', n) - + self.build_op('CALL_FUNCTION', 1) + return self.build_op("BUILD_TUPLE", n) + self.build_op( + "CALL_FUNCTION", 1 ) def emit_build( - self, ind, opl, append_unknown=False, encode_query_vars=None, - kwargs=None + self, ind, opl, append_unknown=False, encode_query_vars=None, kwargs=None ): - ops = b'' + ops = b"" n = len(opl) stack = 0 stack_overhead = 0 for op, elem in opl: if op is None: - ops += self.build_op('LOAD_CONST', self.get_const(elem)) + ops += self.build_op("LOAD_CONST", self.get_const(elem)) stack_overhead = 0 continue - ops += self.build_op('LOAD_CONST', self.get_const(op)) - ops += self.build_op('LOAD_FAST', self.get_var(elem)) - ops += self.build_op('CALL_FUNCTION', 1) + ops += self.build_op("LOAD_CONST", self.get_const(op)) + ops += self.build_op("LOAD_FAST", self.get_var(elem)) + ops += self.build_op("CALL_FUNCTION", 1) stack_overhead = 2 stack += len(opl) peak_stack = stack + stack_overhead dont_build_string = False - needs_build_string = 'BUILD_STRING' not in dis.opmap + needs_build_string = "BUILD_STRING" not in dis.opmap if n <= 1: dont_build_string = True needs_build_string = False if append_unknown: - if 'BUILD_STRING' not in dis.opmap: + if "BUILD_STRING" not in dis.opmap: needs_build_string = True - ops = self.build_op( - 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops - ops += self.build_op('LOAD_FAST', kwargs) + ops = ( + self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY)) + + ops + ) + ops += self.build_op("LOAD_FAST", kwargs) # assemble this in its own buffers because we need to # jump over it uops = bytearray() # run if kwargs. TOS=kwargs - uops += self.build_op( - 'LOAD_CONST', self.get_const(encode_query_vars)) - uops += self.build_op('ROT_TWO') - uops += self.build_op('CALL_FUNCTION', 1) - uops += self.build_op('LOAD_CONST', self.get_const('?')) - uops += self.build_op('ROT_TWO') + uops += self.build_op("LOAD_CONST", self.get_const(encode_query_vars)) + uops += self.build_op("ROT_TWO") + uops += self.build_op("CALL_FUNCTION", 1) + uops += self.build_op("LOAD_CONST", self.get_const("?")) + uops += self.build_op("ROT_TWO") if dont_build_string: uops += self.build_string(n + 2) @@ -977,23 +996,23 @@ class Rule(RuleFactory): if not dont_build_string: # if we're going to build a string, we need to pad out to # a constant length - nops += self.build_op('LOAD_CONST', self.get_const('')) - nops += self.build_op('DUP_TOP') + nops += self.build_op("LOAD_CONST", self.get_const("")) + nops += self.build_op("DUP_TOP") elif needs_build_string: # we inserted the ''.join reference at the bottom of the # stack, but we don't want to call it: throw it away - nops += self.build_op('ROT_TWO') - nops += self.build_op('POP_TOP') - nops += self.build_op('JUMP_FORWARD', len(uops)) + nops += self.build_op("ROT_TWO") + nops += self.build_op("POP_TOP") + nops += self.build_op("JUMP_FORWARD", len(uops)) # this jump needs to take its own length into account. the # simple way to do that is to compute a minimal guess for the # length of the jump instruction, and keep revising it upward - jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', 0) + jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", 0) while True: jump_len = len(jump_op) jump_target = ind + len(ops) + jump_len + len(nops) - jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', jump_target) + jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", jump_target) assert len(jump_op) >= jump_len if len(jump_op) == jump_len: break @@ -1005,8 +1024,7 @@ class Rule(RuleFactory): n += 2 peak_stack = max(peak_stack, stack + 2) elif needs_build_string: - ops = self.build_op( - 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops + ops = self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY)) + ops peak_stack += 1 if not dont_build_string: ops += self.build_string(n) @@ -1022,35 +1040,41 @@ class Rule(RuleFactory): url_encode, charset=self.rule.map.charset, sort=self.rule.map.sort_parameters, - key=self.rule.map.sort_key) + key=self.rule.map.sort_key, + ) for is_dynamic, data in self.rule._trace: - if data == '|' and opl is dom_ops: + if data == "|" and opl is dom_ops: opl = url_ops continue # this seems like a silly case to ever come up but: # if a default is given for a value that appears in the rule, # resolve it to a constant ahead of time if is_dynamic and data in self.defaults: - data = self.rule._converters[data].to_url( - self.defaults[data]) + data = self.rule._converters[data].to_url(self.defaults[data]) is_dynamic = False if not is_dynamic: - opl.append((None, url_quote( - to_bytes(data, self.rule.map.charset), safe='/:|+'))) + opl.append( + ( + None, + url_quote( + to_bytes(data, self.rule.map.charset), safe="/:|+" + ), + ) + ) continue opl.append((self.rule._converters[data].to_url, data)) dom_ops = self.collapse_constants(dom_ops) url_ops = self.collapse_constants(url_ops) - for op, elem in (dom_ops + url_ops): + for op, elem in dom_ops + url_ops: if op is not None: self.get_var(elem) self.add_defaults() argcount = len(self.var) # invalid name for paranoia reasons - self.get_var('.keyword_arguments') + self.get_var(".keyword_arguments") stack = 0 peak_stack = 0 - ops = b'' + ops = b"" if ( not append_unknown and len(dom_ops) == len(url_ops) == 1 @@ -1059,8 +1083,7 @@ class Rule(RuleFactory): # shortcut: just return the constant stack = peak_stack = 1 constant_value = (dom_ops[0][1], url_ops[0][1]) - ops += self.build_op( - 'LOAD_CONST', self.get_const(constant_value)) + ops += self.build_op("LOAD_CONST", self.get_const(constant_value)) else: ps, rv = self.emit_build(len(ops), dom_ops) ops += rv @@ -1068,14 +1091,14 @@ class Rule(RuleFactory): stack += 1 if append_unknown: ps, rv = self.emit_build( - len(ops), url_ops, append_unknown, encode_query_vars, - argcount) + len(ops), url_ops, append_unknown, encode_query_vars, argcount + ) else: ps, rv = self.emit_build(len(ops), url_ops) ops += rv peak_stack = max(stack + ps, peak_stack) - ops += self.build_op('BUILD_TUPLE', 2) - ops += self.build_op('RETURN_VALUE') + ops += self.build_op("BUILD_TUPLE", 2) + ops += self.build_op("RETURN_VALUE") code_args = [ argcount, len(self.var), @@ -1085,10 +1108,10 @@ class Rule(RuleFactory): tuple(self.consts), (), tuple(self.var), - 'generated', - '<builder:%r>' % self.rule.rule, + "generated", + "<builder:%r>" % self.rule.rule, 1, - b'' + b"", ] if sys.version_info >= (3,): code_args[1:1] = [0] @@ -1124,9 +1147,13 @@ class Rule(RuleFactory): :internal: """ - return not self.build_only and self.defaults and \ - self.endpoint == rule.endpoint and self != rule and \ - self.arguments == rule.arguments + return ( + not self.build_only + and self.defaults + and self.endpoint == rule.endpoint + and self != rule + and self.arguments == rule.arguments + ) def suitable_for(self, values, method=None): """Check if the dict of values has enough data for url generation. @@ -1135,8 +1162,11 @@ class Rule(RuleFactory): """ # if a method was given explicitly and that method is not supported # by this rule, this rule is not suitable. - if method is not None and self.methods is not None \ - and method not in self.methods: + if ( + method is not None + and self.methods is not None + and method not in self.methods + ): return False defaults = self.defaults or () @@ -1174,20 +1204,23 @@ class Rule(RuleFactory): :internal: """ - return bool(self.arguments), -len(self._static_weights), self._static_weights,\ - -len(self._argument_weights), self._argument_weights + return ( + bool(self.arguments), + -len(self._static_weights), + self._static_weights, + -len(self._argument_weights), + self._argument_weights, + ) def build_compare_key(self): """The build compare key for sorting. :internal: """ - return 1 if self.alias else 0, -len(self.arguments), \ - -len(self.defaults or ()) + return 1 if self.alias else 0, -len(self.arguments), -len(self.defaults or ()) def __eq__(self, other): - return self.__class__ is other.__class__ and \ - self._trace == other._trace + return self.__class__ is other.__class__ and self._trace == other._trace __hash__ = None @@ -1200,27 +1233,25 @@ class Rule(RuleFactory): @native_string_result def __repr__(self): if self.map is None: - return u'<%s (unbound)>' % self.__class__.__name__ + return u"<%s (unbound)>" % self.__class__.__name__ tmp = [] for is_dynamic, data in self._trace: if is_dynamic: - tmp.append(u'<%s>' % data) + tmp.append(u"<%s>" % data) else: tmp.append(data) - return u'<%s %s%s -> %s>' % ( + return u"<%s %s%s -> %s>" % ( self.__class__.__name__, - repr((u''.join(tmp)).lstrip(u'|')).lstrip(u'u'), - self.methods is not None - and u' (%s)' % u', '.join(self.methods) - or u'', - self.endpoint + repr((u"".join(tmp)).lstrip(u"|")).lstrip(u"u"), + self.methods is not None and u" (%s)" % u", ".join(self.methods) or u"", + self.endpoint, ) class BaseConverter(object): - """Base class for all converters.""" - regex = '[^/]+' + + regex = "[^/]+" weight = 100 def __init__(self, map): @@ -1234,7 +1265,6 @@ class BaseConverter(object): class UnicodeConverter(BaseConverter): - """This converter is the default converter and accepts any string but only one path segment. Thus the string can not include a slash. @@ -1255,21 +1285,17 @@ class UnicodeConverter(BaseConverter): def __init__(self, map, minlength=1, maxlength=None, length=None): BaseConverter.__init__(self, map) if length is not None: - length = '{%d}' % int(length) + length = "{%d}" % int(length) else: if maxlength is None: - maxlength = '' + maxlength = "" else: maxlength = int(maxlength) - length = '{%s,%s}' % ( - int(minlength), - maxlength - ) - self.regex = '[^/]' + length + length = "{%s,%s}" % (int(minlength), maxlength) + self.regex = "[^/]" + length class AnyConverter(BaseConverter): - """Matches one of the items provided. Items can either be Python identifiers or strings:: @@ -1282,11 +1308,10 @@ class AnyConverter(BaseConverter): def __init__(self, map, *items): BaseConverter.__init__(self, map) - self.regex = '(?:%s)' % '|'.join([re.escape(x) for x in items]) + self.regex = "(?:%s)" % "|".join([re.escape(x) for x in items]) class PathConverter(BaseConverter): - """Like the default :class:`UnicodeConverter`, but it also matches slashes. This is useful for wikis and similar applications:: @@ -1295,16 +1320,17 @@ class PathConverter(BaseConverter): :param map: the :class:`Map`. """ - regex = '[^/].*?' + + regex = "[^/].*?" weight = 200 class NumberConverter(BaseConverter): - """Baseclass for `IntegerConverter` and `FloatConverter`. :internal: """ + weight = 50 def __init__(self, map, fixed_digits=0, min=None, max=None, signed=False): @@ -1317,18 +1343,19 @@ class NumberConverter(BaseConverter): self.signed = signed def to_python(self, value): - if (self.fixed_digits and len(value) != self.fixed_digits): + if self.fixed_digits and len(value) != self.fixed_digits: raise ValidationError() value = self.num_convert(value) - if (self.min is not None and value < self.min) or \ - (self.max is not None and value > self.max): + if (self.min is not None and value < self.min) or ( + self.max is not None and value > self.max + ): raise ValidationError() return value def to_url(self, value): value = self.num_convert(value) if self.fixed_digits: - value = ('%%0%sd' % self.fixed_digits) % value + value = ("%%0%sd" % self.fixed_digits) % value return str(value) @property @@ -1337,7 +1364,6 @@ class NumberConverter(BaseConverter): class IntegerConverter(NumberConverter): - """This converter only accepts integer values:: Rule("/page/<int:page>") @@ -1358,12 +1384,12 @@ class IntegerConverter(NumberConverter): .. versionadded:: 0.15 The ``signed`` parameter. """ - regex = r'\d+' + + regex = r"\d+" num_convert = int class FloatConverter(NumberConverter): - """This converter only accepts floating point values:: Rule("/probability/<float:probability>") @@ -1381,7 +1407,8 @@ class FloatConverter(NumberConverter): .. versionadded:: 0.15 The ``signed`` parameter. """ - regex = r'\d+\.\d+' + + regex = r"\d+\.\d+" num_convert = float def __init__(self, map, min=None, max=None, signed=False): @@ -1389,7 +1416,6 @@ class FloatConverter(NumberConverter): class UUIDConverter(BaseConverter): - """This converter only accepts UUID strings:: Rule('/object/<uuid:identifier>') @@ -1398,8 +1424,11 @@ class UUIDConverter(BaseConverter): :param map: the :class:`Map`. """ - regex = r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' \ - r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}' + + regex = ( + r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" + r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" + ) def to_python(self, value): return uuid.UUID(value) @@ -1410,18 +1439,17 @@ class UUIDConverter(BaseConverter): #: the default converter mapping for the map. DEFAULT_CONVERTERS = { - 'default': UnicodeConverter, - 'string': UnicodeConverter, - 'any': AnyConverter, - 'path': PathConverter, - 'int': IntegerConverter, - 'float': FloatConverter, - 'uuid': UUIDConverter, + "default": UnicodeConverter, + "string": UnicodeConverter, + "any": AnyConverter, + "path": PathConverter, + "int": IntegerConverter, + "float": FloatConverter, + "uuid": UUIDConverter, } class Map(object): - """The map class stores all the URL rules and some configuration parameters. Some of the configuration values are only stored on the `Map` instance since those affect all rules, others are just defaults @@ -1458,10 +1486,19 @@ class Map(object): #: A dict of default converters to be used. default_converters = ImmutableDict(DEFAULT_CONVERTERS) - def __init__(self, rules=None, default_subdomain='', charset='utf-8', - strict_slashes=True, redirect_defaults=True, - converters=None, sort_parameters=False, sort_key=None, - encoding_errors='replace', host_matching=False): + def __init__( + self, + rules=None, + default_subdomain="", + charset="utf-8", + strict_slashes=True, + redirect_defaults=True, + converters=None, + sort_parameters=False, + sort_key=None, + encoding_errors="replace", + host_matching=False, + ): self._rules = [] self._rules_by_endpoint = {} self._remap = True @@ -1528,9 +1565,16 @@ class Map(object): self._rules_by_endpoint.setdefault(rule.endpoint, []).append(rule) self._remap = True - def bind(self, server_name, script_name=None, subdomain=None, - url_scheme='http', default_method='GET', path_info=None, - query_args=None): + def bind( + self, + server_name, + script_name=None, + subdomain=None, + url_scheme="http", + default_method="GET", + path_info=None, + query_args=None, + ): """Return a new :class:`MapAdapter` with the details specified to the call. Note that `script_name` will default to ``'/'`` if not further specified or `None`. The `server_name` at least is a requirement @@ -1559,20 +1603,27 @@ class Map(object): server_name = server_name.lower() if self.host_matching: if subdomain is not None: - raise RuntimeError('host matching enabled and a ' - 'subdomain was provided') + raise RuntimeError("host matching enabled and a subdomain was provided") elif subdomain is None: subdomain = self.default_subdomain if script_name is None: - script_name = '/' + script_name = "/" if path_info is None: - path_info = '/' + path_info = "/" try: server_name = _encode_idna(server_name) except UnicodeError: raise BadHost() - return MapAdapter(self, server_name, script_name, subdomain, - url_scheme, path_info, default_method, query_args) + return MapAdapter( + self, + server_name, + script_name, + subdomain, + url_scheme, + path_info, + default_method, + query_args, + ) def bind_to_environ(self, environ, server_name=None, subdomain=None): """Like :meth:`bind` but you can pass it an WSGI environment and it @@ -1618,8 +1669,8 @@ class Map(object): server_name = server_name.lower() if subdomain is None and not self.host_matching: - cur_server_name = wsgi_server_name.split('.') - real_server_name = server_name.split('.') + cur_server_name = wsgi_server_name.split(".") + real_server_name = server_name.split(".") offset = -len(real_server_name) if cur_server_name[offset:] != real_server_name: # This can happen even with valid configs if the server was @@ -1627,22 +1678,28 @@ class Map(object): # Instead of raising an exception like in Werkzeug 0.7 or # earlier we go by an invalid subdomain which will result # in a 404 error on matching. - subdomain = '<invalid>' + subdomain = "<invalid>" else: - subdomain = '.'.join(filter(None, cur_server_name[:offset])) + subdomain = ".".join(filter(None, cur_server_name[:offset])) def _get_wsgi_string(name): val = environ.get(name) if val is not None: return wsgi_decoding_dance(val, self.charset) - script_name = _get_wsgi_string('SCRIPT_NAME') - path_info = _get_wsgi_string('PATH_INFO') - query_args = _get_wsgi_string('QUERY_STRING') - return Map.bind(self, server_name, script_name, - subdomain, environ['wsgi.url_scheme'], - environ['REQUEST_METHOD'], path_info, - query_args=query_args) + script_name = _get_wsgi_string("SCRIPT_NAME") + path_info = _get_wsgi_string("PATH_INFO") + query_args = _get_wsgi_string("QUERY_STRING") + return Map.bind( + self, + server_name, + script_name, + subdomain, + environ["wsgi.url_scheme"], + environ["REQUEST_METHOD"], + path_info, + query_args=query_args, + ) def update(self): """Called before matching and building to keep the compiled rules @@ -1662,7 +1719,7 @@ class Map(object): def __repr__(self): rules = self.iter_rules() - return '%s(%s)' % (self.__class__.__name__, pformat(list(rules))) + return "%s(%s)" % (self.__class__.__name__, pformat(list(rules))) class MapAdapter(object): @@ -1671,13 +1728,22 @@ class MapAdapter(object): the URL matching and building based on runtime information. """ - def __init__(self, map, server_name, script_name, subdomain, - url_scheme, path_info, default_method, query_args=None): + def __init__( + self, + map, + server_name, + script_name, + subdomain, + url_scheme, + path_info, + default_method, + query_args=None, + ): self.map = map self.server_name = to_unicode(server_name) script_name = to_unicode(script_name) - if not script_name.endswith(u'/'): - script_name += u'/' + if not script_name.endswith(u"/"): + script_name += u"/" self.script_name = script_name self.subdomain = to_unicode(subdomain) self.url_scheme = to_unicode(url_scheme) @@ -1685,8 +1751,9 @@ class MapAdapter(object): self.default_method = to_unicode(default_method) self.query_args = query_args - def dispatch(self, view_func, path_info=None, method=None, - catch_http_exceptions=False): + def dispatch( + self, view_func, path_info=None, method=None, catch_http_exceptions=False + ): """Does the complete dispatching process. `view_func` is called with the endpoint and a dict with the values for the view. It should look up the view function, call it, and return a response object @@ -1740,8 +1807,7 @@ class MapAdapter(object): return e raise - def match(self, path_info=None, method=None, return_rule=False, - query_args=None): + def match(self, path_info=None, method=None, return_rule=False, query_args=None): """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). The following things can then happen: @@ -1827,9 +1893,9 @@ class MapAdapter(object): query_args = self.query_args method = (method or self.default_method).upper() - path = u'%s|%s' % ( + path = u"%s|%s" % ( self.map.host_matching and self.server_name or self.subdomain, - path_info and '/%s' % path_info.lstrip('/') + path_info and "/%s" % path_info.lstrip("/"), ) have_match_for = set() @@ -1837,12 +1903,18 @@ class MapAdapter(object): try: rv = rule.match(path, method) except RequestSlash: - raise RequestRedirect(self.make_redirect_url( - url_quote(path_info, self.map.charset, - safe='/:|+') + '/', query_args)) + raise RequestRedirect( + self.make_redirect_url( + url_quote(path_info, self.map.charset, safe="/:|+") + "/", + query_args, + ) + ) except RequestAliasRedirect as e: - raise RequestRedirect(self.make_alias_redirect_url( - path, rule.endpoint, e.matched_values, method, query_args)) + raise RequestRedirect( + self.make_alias_redirect_url( + path, rule.endpoint, e.matched_values, method, query_args + ) + ) if rv is None: continue if rule.methods is not None and method not in rule.methods: @@ -1850,26 +1922,34 @@ class MapAdapter(object): continue if self.map.redirect_defaults: - redirect_url = self.get_default_redirect(rule, method, rv, - query_args) + redirect_url = self.get_default_redirect(rule, method, rv, query_args) if redirect_url is not None: raise RequestRedirect(redirect_url) if rule.redirect_to is not None: if isinstance(rule.redirect_to, string_types): + def _handle_match(match): value = rv[match.group(1)] return rule._converters[match.group(1)].to_url(value) - redirect_url = _simple_rule_re.sub(_handle_match, - rule.redirect_to) + + redirect_url = _simple_rule_re.sub(_handle_match, rule.redirect_to) else: redirect_url = rule.redirect_to(self, **rv) - raise RequestRedirect(str(url_join('%s://%s%s%s' % ( - self.url_scheme or 'http', - self.subdomain + '.' if self.subdomain else '', - self.server_name, - self.script_name - ), redirect_url))) + raise RequestRedirect( + str( + url_join( + "%s://%s%s%s" + % ( + self.url_scheme or "http", + self.subdomain + "." if self.subdomain else "", + self.server_name, + self.script_name, + ), + redirect_url, + ) + ) + ) if return_rule: return rule, rv @@ -1903,7 +1983,7 @@ class MapAdapter(object): .. versionadded:: 0.7 """ try: - self.match(path_info, method='--') + self.match(path_info, method="--") except MethodNotAllowed as e: return e.valid_methods except HTTPException: @@ -1918,13 +1998,13 @@ class MapAdapter(object): if self.map.host_matching: if domain_part is None: return self.server_name - return to_unicode(domain_part, 'ascii') + return to_unicode(domain_part, "ascii") subdomain = domain_part if subdomain is None: subdomain = self.subdomain else: - subdomain = to_unicode(subdomain, 'ascii') - return (subdomain + u'.' if subdomain else u'') + self.server_name + subdomain = to_unicode(subdomain, "ascii") + return (subdomain + u"." if subdomain else u"") + self.server_name def get_default_redirect(self, rule, method, values, query_args): """A helper that returns the URL to redirect to if it finds one. @@ -1939,12 +2019,10 @@ class MapAdapter(object): # with the highest priority up for building. if r is rule: break - if r.provides_defaults_for(rule) and \ - r.suitable_for(values, method): + if r.provides_defaults_for(rule) and r.suitable_for(values, method): values.update(r.defaults) domain_part, path = r.build(values) - return self.make_redirect_url( - path, query_args, domain_part=domain_part) + return self.make_redirect_url(path, query_args, domain_part=domain_part) def encode_query_args(self, query_args): if not isinstance(query_args, string_types): @@ -1956,25 +2034,29 @@ class MapAdapter(object): :internal: """ - suffix = '' + suffix = "" if query_args: - suffix = '?' + self.encode_query_args(query_args) - return str('%s://%s/%s%s' % ( - self.url_scheme or 'http', - self.get_host(domain_part), - posixpath.join(self.script_name[:-1].lstrip('/'), - path_info.lstrip('/')), - suffix - )) + suffix = "?" + self.encode_query_args(query_args) + return str( + "%s://%s/%s%s" + % ( + self.url_scheme or "http", + self.get_host(domain_part), + posixpath.join( + self.script_name[:-1].lstrip("/"), path_info.lstrip("/") + ), + suffix, + ) + ) def make_alias_redirect_url(self, path, endpoint, values, method, query_args): """Internally called to make an alias redirect URL.""" - url = self.build(endpoint, values, method, append_unknown=False, - force_external=True) + url = self.build( + endpoint, values, method, append_unknown=False, force_external=True + ) if query_args: - url += '?' + self.encode_query_args(query_args) - assert url != path, 'detected invalid alias setting. No canonical ' \ - 'URL found' + url += "?" + self.encode_query_args(query_args) + assert url != path, "detected invalid alias setting. No canonical URL found" return url def _partial_build(self, endpoint, values, method, append_unknown): @@ -1985,8 +2067,9 @@ class MapAdapter(object): """ # in case the method is none, try with the default method first if method is None: - rv = self._partial_build(endpoint, values, self.default_method, - append_unknown) + rv = self._partial_build( + endpoint, values, self.default_method, append_unknown + ) if rv is not None: return rv @@ -1998,8 +2081,14 @@ class MapAdapter(object): if rv is not None: return rv - def build(self, endpoint, values=None, method=None, force_external=False, - append_unknown=True): + def build( + self, + endpoint, + values=None, + method=None, + force_external=False, + append_unknown=True, + ): """Building URLs works pretty much the other way round. Instead of `match` you call `build` and pass it the endpoint and a dict of arguments for the placeholders. @@ -2100,10 +2189,13 @@ class MapAdapter(object): (self.map.host_matching and host == self.server_name) or (not self.map.host_matching and domain_part == self.subdomain) ): - return '%s/%s' % (self.script_name.rstrip('/'), path.lstrip('/')) - return str('%s//%s%s/%s' % ( - self.url_scheme + ':' if self.url_scheme else '', - host, - self.script_name[:-1], - path.lstrip('/') - )) + return "%s/%s" % (self.script_name.rstrip("/"), path.lstrip("/")) + return str( + "%s//%s%s/%s" + % ( + self.url_scheme + ":" if self.url_scheme else "", + host, + self.script_name[:-1], + path.lstrip("/"), + ) + ) diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 6d2b8de0..1842afd0 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -8,31 +8,35 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import os -import hmac +import codecs import hashlib +import hmac +import os import posixpath -import codecs -from struct import Struct from random import SystemRandom +from struct import Struct -from werkzeug._compat import range_type, PY2, text_type, izip, to_bytes, \ - to_native - +from ._compat import izip +from ._compat import PY2 +from ._compat import range_type +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_native -SALT_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' +SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" DEFAULT_PBKDF2_ITERATIONS = 150000 - -_pack_int = Struct('>I').pack -_builtin_safe_str_cmp = getattr(hmac, 'compare_digest', None) +_pack_int = Struct(">I").pack +_builtin_safe_str_cmp = getattr(hmac, "compare_digest", None) _sys_rng = SystemRandom() -_os_alt_seps = list(sep for sep in [os.path.sep, os.path.altsep] - if sep not in (None, '/')) +_os_alt_seps = list( + sep for sep in [os.path.sep, os.path.altsep] if sep not in (None, "/") +) -def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, - keylen=None, hashfunc=None): +def pbkdf2_hex( + data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None +): """Like :func:`pbkdf2_bin`, but returns a hex-encoded string. .. versionadded:: 0.9 @@ -47,11 +51,12 @@ def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, from the hashlib module. Defaults to sha256. """ rv = pbkdf2_bin(data, salt, iterations, keylen, hashfunc) - return to_native(codecs.encode(rv, 'hex_codec')) + return to_native(codecs.encode(rv, "hex_codec")) -def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, - keylen=None, hashfunc=None): +def pbkdf2_bin( + data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None +): """Returns a binary digest for the PBKDF2 hash algorithm of `data` with the given `salt`. It iterates `iterations` times and produces a key of `keylen` bytes. By default, SHA-256 is used as hash function; @@ -69,14 +74,14 @@ def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, from the hashlib module. Defaults to sha256. """ if not hashfunc: - hashfunc = 'sha256' + hashfunc = "sha256" data = to_bytes(data) salt = to_bytes(salt) if callable(hashfunc): _test_hash = hashfunc() - hash_name = getattr(_test_hash, 'name', None) + hash_name = getattr(_test_hash, "name", None) else: hash_name = hashfunc return hashlib.pbkdf2_hmac(hash_name, data, salt, iterations, keylen) @@ -91,9 +96,9 @@ def safe_str_cmp(a, b): .. versionadded:: 0.7 """ if isinstance(a, text_type): - a = a.encode('utf-8') + a = a.encode("utf-8") if isinstance(b, text_type): - b = b.encode('utf-8') + b = b.encode("utf-8") if _builtin_safe_str_cmp is not None: return _builtin_safe_str_cmp(a, b) @@ -115,8 +120,8 @@ def safe_str_cmp(a, b): def gen_salt(length): """Generate a random string of SALT_CHARS with specified ``length``.""" if length <= 0: - raise ValueError('Salt length must be positive') - return ''.join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length)) + raise ValueError("Salt length must be positive") + return "".join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length)) def _hash_internal(method, salt, password): @@ -124,31 +129,31 @@ def _hash_internal(method, salt, password): unsalted and salted passwords. In case salted passwords are used hmac is used. """ - if method == 'plain': + if method == "plain": return password, method if isinstance(password, text_type): - password = password.encode('utf-8') + password = password.encode("utf-8") - if method.startswith('pbkdf2:'): - args = method[7:].split(':') + if method.startswith("pbkdf2:"): + args = method[7:].split(":") if len(args) not in (1, 2): - raise ValueError('Invalid number of arguments for PBKDF2') + raise ValueError("Invalid number of arguments for PBKDF2") method = args.pop(0) iterations = args and int(args[0] or 0) or DEFAULT_PBKDF2_ITERATIONS is_pbkdf2 = True - actual_method = 'pbkdf2:%s:%d' % (method, iterations) + actual_method = "pbkdf2:%s:%d" % (method, iterations) else: is_pbkdf2 = False actual_method = method if is_pbkdf2: if not salt: - raise ValueError('Salt is required for PBKDF2') + raise ValueError("Salt is required for PBKDF2") rv = pbkdf2_hex(password, salt, iterations, hashfunc=method) elif salt: if isinstance(salt, text_type): - salt = salt.encode('utf-8') + salt = salt.encode("utf-8") mac = _create_mac(salt, password, method) rv = mac.hexdigest() else: @@ -159,14 +164,17 @@ def _hash_internal(method, salt, password): def _create_mac(key, msg, method): if callable(method): return hmac.HMAC(key, msg, method) - hashfunc = lambda d=b'': hashlib.new(method, d) + + def hashfunc(d=b""): + return hashlib.new(method, d) + # Python 2.7 used ``hasattr(digestmod, '__call__')`` # to detect if hashfunc is callable hashfunc.__call__ = hashfunc return hmac.HMAC(key, msg, hashfunc) -def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8): +def generate_password_hash(password, method="pbkdf2:sha256", salt_length=8): """Hash a password with the given method and salt with a string of the given length. The format of the string returned includes the method that was used so that :func:`check_password_hash` can check the hash. @@ -191,9 +199,9 @@ def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8): to enable PBKDF2. :param salt_length: the length of the salt in letters. """ - salt = gen_salt(salt_length) if method != 'plain' else '' + salt = gen_salt(salt_length) if method != "plain" else "" h, actual_method = _hash_internal(method, salt, password) - return '%s$%s$%s' % (actual_method, salt, h) + return "%s$%s$%s" % (actual_method, salt, h) def check_password_hash(pwhash, password): @@ -207,9 +215,9 @@ def check_password_hash(pwhash, password): :func:`generate_password_hash`. :param password: the plaintext password to compare against the hash. """ - if pwhash.count('$') < 2: + if pwhash.count("$") < 2: return False - method, salt, hashval = pwhash.split('$', 2) + method, salt, hashval = pwhash.split("$", 2) return safe_str_cmp(_hash_internal(method, salt, password)[0], hashval) @@ -222,14 +230,12 @@ def safe_join(directory, *pathnames): """ parts = [directory] for filename in pathnames: - if filename != '': + if filename != "": filename = posixpath.normpath(filename) for sep in _os_alt_seps: if sep in filename: return None - if os.path.isabs(filename) or \ - filename == '..' or \ - filename.startswith('../'): + if os.path.isabs(filename) or filename == ".." or filename.startswith("../"): return None parts.append(filename) return posixpath.join(*parts) diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index 363bb029..4a179b3e 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -37,72 +37,74 @@ """ import io import os +import signal import socket import sys -import signal - - -can_fork = hasattr(os, "fork") +import werkzeug +from ._compat import PY2 +from ._compat import reraise +from ._compat import WIN +from ._compat import wsgi_encoding_dance +from ._internal import _log +from .exceptions import InternalServerError +from .urls import uri_to_iri +from .urls import url_parse +from .urls import url_unquote try: - import termcolor + import socketserver + from http.server import BaseHTTPRequestHandler + from http.server import HTTPServer except ImportError: - termcolor = None + import SocketServer as socketserver + from BaseHTTPServer import HTTPServer + from BaseHTTPServer import BaseHTTPRequestHandler try: import ssl except ImportError: + class _SslDummy(object): def __getattr__(self, name): - raise RuntimeError('SSL support unavailable') + raise RuntimeError("SSL support unavailable") + ssl = _SslDummy() +try: + import termcolor +except ImportError: + termcolor = None + def _get_openssl_crypto_module(): try: from OpenSSL import crypto except ImportError: - raise TypeError('Using ad-hoc certificates requires the pyOpenSSL ' - 'library.') + raise TypeError("Using ad-hoc certificates requires the pyOpenSSL library.") else: return crypto -try: - import SocketServer as socketserver - from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -except ImportError: - import socketserver - from http.server import HTTPServer, BaseHTTPRequestHandler - ThreadingMixIn = socketserver.ThreadingMixIn +can_fork = hasattr(os, "fork") if can_fork: ForkingMixIn = socketserver.ForkingMixIn else: + class ForkingMixIn(object): pass + try: af_unix = socket.AF_UNIX except AttributeError: af_unix = None -import werkzeug -from werkzeug._internal import _log -from werkzeug._compat import PY2, WIN, reraise, wsgi_encoding_dance -from werkzeug.urls import ( - url_parse, - url_unquote, - uri_to_iri, -) -from werkzeug.exceptions import InternalServerError - - LISTEN_QUEUE = 128 -can_open_by_fd = not WIN and hasattr(socket, 'fromfd') +can_open_by_fd = not WIN and hasattr(socket, "fromfd") # On Python 3, ConnectionError represents the same errnos as # socket.error from Python 2, while socket.error is an alias for the @@ -126,12 +128,12 @@ class DechunkedInput(io.RawIOBase): def read_chunk_len(self): try: - line = self._rfile.readline().decode('latin1') + line = self._rfile.readline().decode("latin1") _len = int(line.strip(), 16) except ValueError: - raise IOError('Invalid chunk header') + raise IOError("Invalid chunk header") if _len < 0: - raise IOError('Negative chunk length not allowed') + raise IOError("Negative chunk length not allowed") return _len def readinto(self, buf): @@ -152,7 +154,7 @@ class DechunkedInput(io.RawIOBase): # buffer. If this operation fully consumes the chunk, this will # reset self._len to 0. n = min(len(buf), self._len) - buf[read:read + n] = self._rfile.read(n) + buf[read : read + n] = self._rfile.read(n) self._len -= n read += n @@ -160,8 +162,8 @@ class DechunkedInput(io.RawIOBase): # Skip the terminating newline of a chunk that has been fully # consumed. This also applies to the 0-sized final chunk terminator = self._rfile.readline() - if terminator not in (b'\n', b'\r\n', b'\r'): - raise IOError('Missing chunk terminating newline') + if terminator not in (b"\n", b"\r\n", b"\r"): + raise IOError("Missing chunk terminating newline") return read @@ -172,7 +174,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): @property def server_version(self): - return 'Werkzeug/' + werkzeug.__version__ + return "Werkzeug/" + werkzeug.__version__ def make_environ(self): request_url = url_parse(self.path) @@ -180,9 +182,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): def shutdown_server(): self.server.shutdown_signal = True - url_scheme = 'http' if self.server.ssl_context is None else 'https' + url_scheme = "http" if self.server.ssl_context is None else "https" if not self.client_address: - self.client_address = '<local>' + self.client_address = "<local>" if isinstance(self.client_address, str): self.client_address = (self.client_address, 0) else: @@ -190,57 +192,57 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): path_info = url_unquote(request_url.path) environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': url_scheme, - 'wsgi.input': self.rfile, - 'wsgi.errors': sys.stderr, - 'wsgi.multithread': self.server.multithread, - 'wsgi.multiprocess': self.server.multiprocess, - 'wsgi.run_once': False, - 'werkzeug.server.shutdown': shutdown_server, - 'SERVER_SOFTWARE': self.server_version, - 'REQUEST_METHOD': self.command, - 'SCRIPT_NAME': '', - 'PATH_INFO': wsgi_encoding_dance(path_info), - 'QUERY_STRING': wsgi_encoding_dance(request_url.query), + "wsgi.version": (1, 0), + "wsgi.url_scheme": url_scheme, + "wsgi.input": self.rfile, + "wsgi.errors": sys.stderr, + "wsgi.multithread": self.server.multithread, + "wsgi.multiprocess": self.server.multiprocess, + "wsgi.run_once": False, + "werkzeug.server.shutdown": shutdown_server, + "SERVER_SOFTWARE": self.server_version, + "REQUEST_METHOD": self.command, + "SCRIPT_NAME": "", + "PATH_INFO": wsgi_encoding_dance(path_info), + "QUERY_STRING": wsgi_encoding_dance(request_url.query), # Non-standard, added by mod_wsgi, uWSGI "REQUEST_URI": wsgi_encoding_dance(self.path), # Non-standard, added by gunicorn "RAW_URI": wsgi_encoding_dance(self.path), - 'REMOTE_ADDR': self.address_string(), - 'REMOTE_PORT': self.port_integer(), - 'SERVER_NAME': self.server.server_address[0], - 'SERVER_PORT': str(self.server.server_address[1]), - 'SERVER_PROTOCOL': self.request_version + "REMOTE_ADDR": self.address_string(), + "REMOTE_PORT": self.port_integer(), + "SERVER_NAME": self.server.server_address[0], + "SERVER_PORT": str(self.server.server_address[1]), + "SERVER_PROTOCOL": self.request_version, } for key, value in self.get_header_items(): - key = key.upper().replace('-', '_') - if key not in ('CONTENT_TYPE', 'CONTENT_LENGTH'): - key = 'HTTP_' + key + key = key.upper().replace("-", "_") + if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): + key = "HTTP_" + key if key in environ: value = "{},{}".format(environ[key], value) environ[key] = value - if environ.get('HTTP_TRANSFER_ENCODING', '').strip().lower() == 'chunked': - environ['wsgi.input_terminated'] = True - environ['wsgi.input'] = DechunkedInput(environ['wsgi.input']) + if environ.get("HTTP_TRANSFER_ENCODING", "").strip().lower() == "chunked": + environ["wsgi.input_terminated"] = True + environ["wsgi.input"] = DechunkedInput(environ["wsgi.input"]) if request_url.scheme and request_url.netloc: - environ['HTTP_HOST'] = request_url.netloc + environ["HTTP_HOST"] = request_url.netloc return environ def run_wsgi(self): - if self.headers.get('Expect', '').lower().strip() == '100-continue': - self.wfile.write(b'HTTP/1.1 100 Continue\r\n\r\n') + if self.headers.get("Expect", "").lower().strip() == "100-continue": + self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.environ = environ = self.make_environ() headers_set = [] headers_sent = [] def write(data): - assert headers_set, 'write() before start_response' + assert headers_set, "write() before start_response" if not headers_sent: status, response_headers = headers_sent[:] = headers_set try: @@ -254,18 +256,21 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): self.send_header(key, value) key = key.lower() header_keys.add(key) - if not ('content-length' in header_keys - or environ['REQUEST_METHOD'] == 'HEAD' - or code < 200 or code in (204, 304)): + if not ( + "content-length" in header_keys + or environ["REQUEST_METHOD"] == "HEAD" + or code < 200 + or code in (204, 304) + ): self.close_connection = True - self.send_header('Connection', 'close') - if 'server' not in header_keys: - self.send_header('Server', self.version_string()) - if 'date' not in header_keys: - self.send_header('Date', self.date_time_string()) + self.send_header("Connection", "close") + if "server" not in header_keys: + self.send_header("Server", self.version_string()) + if "date" not in header_keys: + self.send_header("Date", self.date_time_string()) self.end_headers() - assert isinstance(data, bytes), 'applications must write bytes' + assert isinstance(data, bytes), "applications must write bytes" self.wfile.write(data) self.wfile.flush() @@ -277,7 +282,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): finally: exc_info = None elif headers_set: - raise AssertionError('Headers already set') + raise AssertionError("Headers already set") headers_set[:] = [status, response_headers] return write @@ -287,9 +292,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): for data in application_iter: write(data) if not headers_sent: - write(b'') + write(b"") finally: - if hasattr(application_iter, 'close'): + if hasattr(application_iter, "close"): application_iter.close() application_iter = None @@ -300,7 +305,8 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): except Exception: if self.server.passthrough_errors: raise - from werkzeug.debug.tbtools import get_current_traceback + from .debug.tbtools import get_current_traceback + traceback = get_current_traceback(ignore_system_exceptions=True) try: # if we haven't yet sent the headers but they are set @@ -310,8 +316,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): execute(InternalServerError()) except Exception: pass - self.server.log('error', 'Error on request:\n%s', - traceback.plaintext) + self.server.log("error", "Error on request:\n%s", traceback.plaintext) def handle(self): """Handles a request ignoring dropped connections.""" @@ -332,7 +337,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): later. It's the best we can do. """ # Windows does not provide SIGKILL, go with SIGTERM then. - sig = getattr(signal, 'SIGKILL', signal.SIGTERM) + sig = getattr(signal, "SIGKILL", signal.SIGTERM) # reloader active if is_running_from_reloader(): os.kill(os.getpid(), sig) @@ -358,19 +363,19 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): """Send the response header and log the response code.""" self.log_request(code) if message is None: - message = code in self.responses and self.responses[code][0] or '' - if self.request_version != 'HTTP/0.9': + message = code in self.responses and self.responses[code][0] or "" + if self.request_version != "HTTP/0.9": hdr = "%s %d %s\r\n" % (self.protocol_version, code, message) - self.wfile.write(hdr.encode('ascii')) + self.wfile.write(hdr.encode("ascii")) def version_string(self): return BaseHTTPRequestHandler.version_string(self).strip() def address_string(self): - if getattr(self, 'environ', None): - return self.environ['REMOTE_ADDR'] + if getattr(self, "environ", None): + return self.environ["REMOTE_ADDR"] elif not self.client_address: - return '<local>' + return "<local>" elif isinstance(self.client_address, str): return self.client_address else: @@ -379,7 +384,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): def port_integer(self): return self.client_address[1] - def log_request(self, code='-', size='-'): + def log_request(self, code="-", size="-"): try: path = uri_to_iri(self.path) msg = "%s %s %s" % (self.command, path, self.request_version) @@ -392,33 +397,35 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): if termcolor: color = termcolor.colored - if code[0] == '1': # 1xx - Informational - msg = color(msg, attrs=['bold']) - elif code[0] == '2': # 2xx - Success - msg = color(msg, color='white') - elif code == '304': # 304 - Resource Not Modified - msg = color(msg, color='cyan') - elif code[0] == '3': # 3xx - Redirection - msg = color(msg, color='green') - elif code == '404': # 404 - Resource Not Found - msg = color(msg, color='yellow') - elif code[0] == '4': # 4xx - Client Error - msg = color(msg, color='red', attrs=['bold']) - else: # 5xx, or any other response - msg = color(msg, color='magenta', attrs=['bold']) - - self.log('info', '"%s" %s %s', msg, code, size) + if code[0] == "1": # 1xx - Informational + msg = color(msg, attrs=["bold"]) + elif code[0] == "2": # 2xx - Success + msg = color(msg, color="white") + elif code == "304": # 304 - Resource Not Modified + msg = color(msg, color="cyan") + elif code[0] == "3": # 3xx - Redirection + msg = color(msg, color="green") + elif code == "404": # 404 - Resource Not Found + msg = color(msg, color="yellow") + elif code[0] == "4": # 4xx - Client Error + msg = color(msg, color="red", attrs=["bold"]) + else: # 5xx, or any other response + msg = color(msg, color="magenta", attrs=["bold"]) + + self.log("info", '"%s" %s %s', msg, code, size) def log_error(self, *args): - self.log('error', *args) + self.log("error", *args) def log_message(self, format, *args): - self.log('info', format, *args) + self.log("info", format, *args) def log(self, type, message, *args): - _log(type, '%s - - [%s] %s\n' % (self.address_string(), - self.log_date_time_string(), - message % args)) + _log( + type, + "%s - - [%s] %s\n" + % (self.address_string(), self.log_date_time_string(), message % args), + ) def get_header_items(self): """ @@ -434,15 +441,18 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object): :return: List of tuples containing header hey/value pairs """ if PY2: - # For Python 2, process the headers manually according to W3C RFC 2616 Section 4.2 + # For Python 2, process the headers manually according to + # W3C RFC 2616 Section 4.2. items = [] for header in self.headers.headers: - # Remove the \n\r from the header and split on the : to get the field name and value + # Remove "\n\r" from the header and split on ":" to get + # the field name and value. key, value = header[0:-2].split(":", 1) - # Add the key and the value once stripped of leading white space. The specification - # allows for stripping trailing white space but the Python 3 code does not strip - # trailing white space. Therefore, trailing space will be left as is to match the - # Python 3 behavior + # Add the key and the value once stripped of leading + # white space. The specification allows for stripping + # trailing white space but the Python 3 code does not + # strip trailing white space. Therefore, trailing space + # will be left as is to match the Python 3 behavior. items.append((key, value.lstrip())) else: items = self.headers.items() @@ -456,11 +466,12 @@ BaseRequestHandler = WSGIRequestHandler def generate_adhoc_ssl_pair(cn=None): from random import random + crypto = _get_openssl_crypto_module() # pretty damn sure that this is not actually accepted by anyone if cn is None: - cn = '*' + cn = "*" cert = crypto.X509() cert.set_serial_number(int(random() * sys.maxsize)) @@ -469,7 +480,7 @@ def generate_adhoc_ssl_pair(cn=None): subject = cert.get_subject() subject.CN = cn - subject.O = 'Dummy Certificate' # noqa: E741 + subject.O = "Dummy Certificate" # noqa: E741 issuer = cert.get_issuer() issuer.CN = subject.CN @@ -478,7 +489,7 @@ def generate_adhoc_ssl_pair(cn=None): pkey = crypto.PKey() pkey.generate_key(crypto.TYPE_RSA, 2048) cert.set_pubkey(pkey) - cert.sign(pkey, 'sha256') + cert.sign(pkey, "sha256") return cert, pkey @@ -502,16 +513,17 @@ def make_ssl_devcert(base_path, host=None, cn=None): :param cn: the `CN` to use. """ from OpenSSL import crypto + if host is not None: - cn = '*.%s/CN=%s' % (host, host) + cn = "*.%s/CN=%s" % (host, host) cert, pkey = generate_adhoc_ssl_pair(cn=cn) - cert_file = base_path + '.crt' - pkey_file = base_path + '.key' + cert_file = base_path + ".crt" + pkey_file = base_path + ".key" - with open(cert_file, 'wb') as f: + with open(cert_file, "wb") as f: f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) - with open(pkey_file, 'wb') as f: + with open(pkey_file, "wb") as f: f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)) return cert_file, pkey_file @@ -557,8 +569,8 @@ def load_ssl_context(cert_file, pkey_file=None, protocol=None): class _SSLContext(object): - '''A dummy class with a small subset of Python3's ``ssl.SSLContext``, only - intended to be used with and by Werkzeug.''' + """A dummy class with a small subset of Python3's ``ssl.SSLContext``, only + intended to be used with and by Werkzeug.""" def __init__(self, protocol): self._protocol = protocol @@ -572,9 +584,13 @@ class _SSLContext(object): self._password = password def wrap_socket(self, sock, **kwargs): - return ssl.wrap_socket(sock, keyfile=self._keyfile, - certfile=self._certfile, - ssl_version=self._protocol, **kwargs) + return ssl.wrap_socket( + sock, + keyfile=self._keyfile, + certfile=self._certfile, + ssl_version=self._protocol, + **kwargs + ) def is_ssl_error(error=None): @@ -582,6 +598,7 @@ def is_ssl_error(error=None): exc_types = (ssl.SSLError,) try: from OpenSSL.SSL import Error + exc_types += (Error,) except ImportError: pass @@ -606,9 +623,9 @@ def select_address_family(host, port): # return info[0][0] # except socket.gaierror: # pass - if host.startswith('unix://'): + if host.startswith("unix://"): return socket.AF_UNIX - elif ':' in host and hasattr(socket, 'AF_INET6'): + elif ":" in host and hasattr(socket, "AF_INET6"): return socket.AF_INET6 return socket.AF_INET @@ -617,10 +634,11 @@ def get_sockaddr(host, port, family): """Return a fully qualified socket address that can be passed to :func:`socket.bind`.""" if family == af_unix: - return host.split('://', 1)[1] + return host.split("://", 1)[1] try: res = socket.getaddrinfo( - host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP + ) except socket.gaierror: return host, port return res[0][4] @@ -629,13 +647,20 @@ def get_sockaddr(host, port, family): class BaseWSGIServer(HTTPServer, object): """Simple single-threaded, single-process WSGI server.""" + multithread = False multiprocess = False request_queue_size = LISTEN_QUEUE def __init__( - self, host, port, app, handler=None, passthrough_errors=False, - ssl_context=None, fd=None + self, + host, + port, + app, + handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, ): if handler is None: handler = WSGIRequestHandler @@ -643,17 +668,13 @@ class BaseWSGIServer(HTTPServer, object): self.address_family = select_address_family(host, port) if fd is not None: - real_sock = socket.fromfd( - fd, self.address_family, socket.SOCK_STREAM) + real_sock = socket.fromfd(fd, self.address_family, socket.SOCK_STREAM) port = 0 server_address = get_sockaddr(host, int(port), self.address_family) # remove socket file if it already exists - if ( - self.address_family == af_unix - and os.path.exists(server_address) - ): + if self.address_family == af_unix and os.path.exists(server_address): os.unlink(server_address) HTTPServer.__init__(self, server_address, handler) @@ -672,7 +693,7 @@ class BaseWSGIServer(HTTPServer, object): if ssl_context is not None: if isinstance(ssl_context, tuple): ssl_context = load_ssl_context(*ssl_context) - if ssl_context == 'adhoc': + if ssl_context == "adhoc": ssl_context = generate_adhoc_ssl_context() # If we are on Python 2 the return value from socket.fromfd # is an internal socket object but what we need for ssl wrap @@ -714,6 +735,7 @@ class BaseWSGIServer(HTTPServer, object): class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer): """A WSGI server that does threading.""" + multithread = True daemon_threads = True @@ -721,35 +743,63 @@ class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer): class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer): """A WSGI server that does forking.""" + multiprocess = True - def __init__(self, host, port, app, processes=40, handler=None, - passthrough_errors=False, ssl_context=None, fd=None): + def __init__( + self, + host, + port, + app, + processes=40, + handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, + ): if not can_fork: - raise ValueError('Your platform does not support forking.') - BaseWSGIServer.__init__(self, host, port, app, handler, - passthrough_errors, ssl_context, fd) + raise ValueError("Your platform does not support forking.") + BaseWSGIServer.__init__( + self, host, port, app, handler, passthrough_errors, ssl_context, fd + ) self.max_children = processes -def make_server(host=None, port=None, app=None, threaded=False, processes=1, - request_handler=None, passthrough_errors=False, - ssl_context=None, fd=None): +def make_server( + host=None, + port=None, + app=None, + threaded=False, + processes=1, + request_handler=None, + passthrough_errors=False, + ssl_context=None, + fd=None, +): """Create a new server instance that is either threaded, or forks or just processes one request after another. """ if threaded and processes > 1: - raise ValueError("cannot have a multithreaded and " - "multi process server.") + raise ValueError("cannot have a multithreaded and multi process server.") elif threaded: - return ThreadedWSGIServer(host, port, app, request_handler, - passthrough_errors, ssl_context, fd=fd) + return ThreadedWSGIServer( + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd + ) elif processes > 1: - return ForkingWSGIServer(host, port, app, processes, request_handler, - passthrough_errors, ssl_context, fd=fd) + return ForkingWSGIServer( + host, + port, + app, + processes, + request_handler, + passthrough_errors, + ssl_context, + fd=fd, + ) else: - return BaseWSGIServer(host, port, app, request_handler, - passthrough_errors, ssl_context, fd=fd) + return BaseWSGIServer( + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd + ) def is_running_from_reloader(): @@ -758,15 +808,26 @@ def is_running_from_reloader(): .. versionadded:: 0.10 """ - return os.environ.get('WERKZEUG_RUN_MAIN') == 'true' - - -def run_simple(hostname, port, application, use_reloader=False, - use_debugger=False, use_evalex=True, - extra_files=None, reloader_interval=1, - reloader_type='auto', threaded=False, - processes=1, request_handler=None, static_files=None, - passthrough_errors=False, ssl_context=None): + return os.environ.get("WERKZEUG_RUN_MAIN") == "true" + + +def run_simple( + hostname, + port, + application, + use_reloader=False, + use_debugger=False, + use_evalex=True, + extra_files=None, + reloader_interval=1, + reloader_type="auto", + threaded=False, + processes=1, + request_handler=None, + static_files=None, + passthrough_errors=False, + ssl_context=None, +): """Start a WSGI application. Optional features include a reloader, multithreading and fork support. @@ -837,36 +898,50 @@ def run_simple(hostname, port, application, use_reloader=False, to disable SSL (which is the default). """ if not isinstance(port, int): - raise TypeError('port must be an integer') + raise TypeError("port must be an integer") if use_debugger: - from werkzeug.debug import DebuggedApplication + from .debug import DebuggedApplication + application = DebuggedApplication(application, use_evalex) if static_files: - from werkzeug.wsgi import SharedDataMiddleware + from .middleware.shared_data import SharedDataMiddleware + application = SharedDataMiddleware(application, static_files) def log_startup(sock): - display_hostname = hostname if hostname not in ('', '*') else 'localhost' - quit_msg = '(Press CTRL+C to quit)' + display_hostname = hostname if hostname not in ("", "*") else "localhost" + quit_msg = "(Press CTRL+C to quit)" if sock.family == af_unix: - _log('info', ' * Running on %s %s', display_hostname, quit_msg) + _log("info", " * Running on %s %s", display_hostname, quit_msg) else: - if ':' in display_hostname: - display_hostname = '[%s]' % display_hostname + if ":" in display_hostname: + display_hostname = "[%s]" % display_hostname port = sock.getsockname()[1] - _log('info', ' * Running on %s://%s:%d/ %s', - 'http' if ssl_context is None else 'https', - display_hostname, port, quit_msg) + _log( + "info", + " * Running on %s://%s:%d/ %s", + "http" if ssl_context is None else "https", + display_hostname, + port, + quit_msg, + ) def inner(): try: - fd = int(os.environ['WERKZEUG_SERVER_FD']) + fd = int(os.environ["WERKZEUG_SERVER_FD"]) except (LookupError, ValueError): fd = None - srv = make_server(hostname, port, application, threaded, - processes, request_handler, - passthrough_errors, ssl_context, - fd=fd) + srv = make_server( + hostname, + port, + application, + threaded, + processes, + request_handler, + passthrough_errors, + ssl_context, + fd=fd, + ) if fd is None: log_startup(srv.socket) srv.serve_forever() @@ -877,9 +952,11 @@ def run_simple(hostname, port, application, use_reloader=False, # port is actually available. if not is_running_from_reloader(): if port == 0 and not can_open_by_fd: - raise ValueError('Cannot bind to a random port with enabled ' - 'reloader if the Python interpreter does ' - 'not support socket opening by fd.') + raise ValueError( + "Cannot bind to a random port with enabled " + "reloader if the Python interpreter does " + "not support socket opening by fd." + ) # Create and destroy a socket so that any exceptions are # raised before we spawn a separate Python interpreter and @@ -889,26 +966,26 @@ def run_simple(hostname, port, application, use_reloader=False, s = socket.socket(address_family, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(server_address) - if hasattr(s, 'set_inheritable'): + if hasattr(s, "set_inheritable"): s.set_inheritable(True) # If we can open the socket by file descriptor, then we can just # reuse this one and our socket will survive the restarts. if can_open_by_fd: - os.environ['WERKZEUG_SERVER_FD'] = str(s.fileno()) + os.environ["WERKZEUG_SERVER_FD"] = str(s.fileno()) s.listen(LISTEN_QUEUE) log_startup(s) else: s.close() if address_family == af_unix: - _log('info', "Unlinking %s" % server_address) + _log("info", "Unlinking %s" % server_address) os.unlink(server_address) # Do not use relative imports, otherwise "python -m werkzeug.serving" # breaks. - from werkzeug._reloader import run_with_reloader - run_with_reloader(inner, extra_files, reloader_interval, - reloader_type) + from ._reloader import run_with_reloader + + run_with_reloader(inner, extra_files, reloader_interval, reloader_type) else: inner() @@ -916,46 +993,63 @@ def run_simple(hostname, port, application, use_reloader=False, def run_with_reloader(*args, **kwargs): # People keep using undocumented APIs. Do not use this function # please, we do not guarantee that it continues working. - from werkzeug._reloader import run_with_reloader + from ._reloader import run_with_reloader + return run_with_reloader(*args, **kwargs) def main(): - '''A simple command-line interface for :py:func:`run_simple`.''' + """A simple command-line interface for :py:func:`run_simple`.""" # in contrast to argparse, this works at least under Python < 2.7 import optparse - from werkzeug.utils import import_string - - parser = optparse.OptionParser( - usage='Usage: %prog [options] app_module:app_object') - parser.add_option('-b', '--bind', dest='address', - help='The hostname:port the app should listen on.') - parser.add_option('-d', '--debug', dest='use_debugger', - action='store_true', default=False, - help='Use Werkzeug\'s debugger.') - parser.add_option('-r', '--reload', dest='use_reloader', - action='store_true', default=False, - help='Reload Python process if modules change.') + from .utils import import_string + + parser = optparse.OptionParser(usage="Usage: %prog [options] app_module:app_object") + parser.add_option( + "-b", + "--bind", + dest="address", + help="The hostname:port the app should listen on.", + ) + parser.add_option( + "-d", + "--debug", + dest="use_debugger", + action="store_true", + default=False, + help="Use Werkzeug's debugger.", + ) + parser.add_option( + "-r", + "--reload", + dest="use_reloader", + action="store_true", + default=False, + help="Reload Python process if modules change.", + ) options, args = parser.parse_args() hostname, port = None, None if options.address: - address = options.address.split(':') + address = options.address.split(":") hostname = address[0] if len(address) > 1: port = address[1] if len(args) != 1: - sys.stdout.write('No application supplied, or too much. See --help\n') + sys.stdout.write("No application supplied, or too much. See --help\n") sys.exit(1) app = import_string(args[0]) run_simple( - hostname=(hostname or '127.0.0.1'), port=int(port or 5000), - application=app, use_reloader=options.use_reloader, - use_debugger=options.use_debugger + hostname=(hostname or "127.0.0.1"), + port=int(port or 5000), + application=app, + use_reloader=options.use_reloader, + use_debugger=options.use_debugger, ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index 8f04ec21..9aeb4127 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -8,56 +8,69 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys import mimetypes -from time import time - -from random import random +import sys +from io import BytesIO from itertools import chain +from random import random from tempfile import TemporaryFile -from io import BytesIO +from time import time + +from ._compat import iteritems +from ._compat import iterlists +from ._compat import itervalues +from ._compat import make_literal_wrapper +from ._compat import reraise +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import wsgi_encoding_dance +from ._internal import _get_environ +from .datastructures import CallbackDict +from .datastructures import CombinedMultiDict +from .datastructures import EnvironHeaders +from .datastructures import FileMultiDict +from .datastructures import FileStorage +from .datastructures import Headers +from .datastructures import MultiDict +from .http import dump_cookie +from .http import dump_options_header +from .http import parse_options_header +from .urls import iri_to_uri +from .urls import url_encode +from .urls import url_fix +from .urls import url_parse +from .urls import url_unparse +from .urls import url_unquote +from .utils import get_content_type +from .wrappers import BaseRequest +from .wsgi import ClosingIterator +from .wsgi import get_current_url try: - from urllib2 import Request as U2Request -except ImportError: from urllib.request import Request as U2Request +except ImportError: + from urllib2 import Request as U2Request + try: from http.cookiejar import CookieJar -except ImportError: # Py2 +except ImportError: from cookielib import CookieJar -from werkzeug._compat import iterlists, iteritems, itervalues, to_bytes, \ - string_types, text_type, reraise, wsgi_encoding_dance, \ - make_literal_wrapper -from werkzeug._internal import _get_environ -from werkzeug.wrappers import BaseRequest -from werkzeug.urls import url_encode, url_fix, iri_to_uri, url_unquote, \ - url_unparse, url_parse -from werkzeug.wsgi import get_current_url, ClosingIterator -from werkzeug.utils import dump_cookie, get_content_type -from werkzeug.datastructures import ( - FileMultiDict, - MultiDict, - CombinedMultiDict, - Headers, - FileStorage, - CallbackDict, - EnvironHeaders, -) -from werkzeug.http import dump_options_header, parse_options_header - - -def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, - boundary=None, charset='utf-8'): + +def stream_encode_multipart( + values, use_tempfile=True, threshold=1024 * 500, boundary=None, charset="utf-8" +): """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. """ if boundary is None: - boundary = '---------------WerkzeugFormPart_%s%s' % (time(), random()) + boundary = "---------------WerkzeugFormPart_%s%s" % (time(), random()) _closure = [BytesIO(), 0, False] if use_tempfile: + def write_binary(string): stream, total_length, on_disk = _closure if on_disk: @@ -67,12 +80,13 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, if length + _closure[1] <= threshold: stream.write(string) else: - new_stream = TemporaryFile('wb+') + new_stream = TemporaryFile("wb+") new_stream.write(stream.getvalue()) new_stream.write(string) _closure[0] = new_stream _closure[2] = True _closure[1] = total_length + length + else: write_binary = _closure[0].write @@ -84,22 +98,22 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, for key, values in iterlists(values): for value in values: - write('--%s\r\nContent-Disposition: form-data; name="%s"' % - (boundary, key)) - reader = getattr(value, 'read', None) + write('--%s\r\nContent-Disposition: form-data; name="%s"' % (boundary, key)) + reader = getattr(value, "read", None) if reader is not None: - filename = getattr(value, 'filename', - getattr(value, 'name', None)) - content_type = getattr(value, 'content_type', None) + filename = getattr(value, "filename", getattr(value, "name", None)) + content_type = getattr(value, "content_type", None) if content_type is None: - content_type = filename and \ - mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' + content_type = ( + filename + and mimetypes.guess_type(filename)[0] + or "application/octet-stream" + ) if filename is not None: write('; filename="%s"\r\n' % filename) else: - write('\r\n') - write('Content-Type: %s\r\n\r\n' % content_type) + write("\r\n") + write("Content-Type: %s\r\n\r\n" % content_type) while 1: chunk = reader(16384) if not chunk: @@ -110,22 +124,23 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500, value = str(value) value = to_bytes(value, charset) - write('\r\n\r\n') + write("\r\n\r\n") write_binary(value) - write('\r\n') - write('--%s--\r\n' % boundary) + write("\r\n") + write("--%s--\r\n" % boundary) length = int(_closure[0].tell()) _closure[0].seek(0) return _closure[0], length, boundary -def encode_multipart(values, boundary=None, charset='utf-8'): +def encode_multipart(values, boundary=None, charset="utf-8"): """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is a bytestring. """ stream, length, boundary = stream_encode_multipart( - values, use_tempfile=False, boundary=boundary, charset=charset) + values, use_tempfile=False, boundary=boundary, charset=charset + ) return boundary, stream.read() @@ -135,6 +150,7 @@ def File(fd, filename=None, mimetype=None): .. deprecated:: 0.5 """ from warnings import warn + warn( "'werkzeug.test.File' is deprecated as of version 0.5 and will" " be removed in version 1.0. Use 'EnvironBuilder' or" @@ -194,17 +210,16 @@ class _TestCookieJar(CookieJar): """ cvals = [] for cookie in self: - cvals.append('%s=%s' % (cookie.name, cookie.value)) + cvals.append("%s=%s" % (cookie.name, cookie.value)) if cvals: - environ['HTTP_COOKIE'] = '; '.join(cvals) + environ["HTTP_COOKIE"] = "; ".join(cvals) def extract_wsgi(self, environ, headers): """Extract the server's set-cookie headers as cookies into the cookie jar. """ self.extract_cookies( - _TestCookieResponse(headers), - U2Request(get_current_url(environ)), + _TestCookieResponse(headers), U2Request(get_current_url(environ)) ) @@ -307,7 +322,7 @@ class EnvironBuilder(object): """ #: the server protocol to use. defaults to HTTP/1.1 - server_protocol = 'HTTP/1.1' + server_protocol = "HTTP/1.1" #: the wsgi version to use. defaults to (1, 0) wsgi_version = (1, 0) @@ -316,21 +331,37 @@ class EnvironBuilder(object): request_class = BaseRequest import json + #: The serialization function used when ``json`` is passed. json_dumps = staticmethod(json.dumps) del json - def __init__(self, path='/', base_url=None, query_string=None, - method='GET', input_stream=None, content_type=None, - content_length=None, errors_stream=None, multithread=False, - multiprocess=False, run_once=False, headers=None, data=None, - environ_base=None, environ_overrides=None, charset='utf-8', - mimetype=None, json=None): + def __init__( + self, + path="/", + base_url=None, + query_string=None, + method="GET", + input_stream=None, + content_type=None, + content_length=None, + errors_stream=None, + multithread=False, + multiprocess=False, + run_once=False, + headers=None, + data=None, + environ_base=None, + environ_overrides=None, + charset="utf-8", + mimetype=None, + json=None, + ): path_s = make_literal_wrapper(path) - if query_string is not None and path_s('?') in path: - raise ValueError('Query string is defined in the path and as an argument') - if query_string is None and path_s('?') in path: - path, query_string = path.split(path_s('?'), 1) + if query_string is not None and path_s("?") in path: + raise ValueError("Query string is defined in the path and as an argument") + if query_string is None and path_s("?") in path: + path, query_string = path.split(path_s("?"), 1) self.charset = charset self.path = iri_to_uri(path) if base_url is not None: @@ -375,8 +406,8 @@ class EnvironBuilder(object): if data: if input_stream is not None: - raise TypeError('can\'t provide input stream and data') - if hasattr(data, 'read'): + raise TypeError("can't provide input stream and data") + if hasattr(data, "read"): data = data.read() if isinstance(data, text_type): data = data.encode(self.charset) @@ -386,8 +417,7 @@ class EnvironBuilder(object): self.content_length = len(data) else: for key, value in _iter_data(data): - if isinstance(value, (tuple, dict)) or \ - hasattr(value, 'read'): + if isinstance(value, (tuple, dict)) or hasattr(value, "read"): self._add_file_from_data(key, value) else: self.form.setlistdefault(key).append(value) @@ -406,9 +436,7 @@ class EnvironBuilder(object): out = { "path": environ["PATH_INFO"], "base_url": cls._make_base_url( - environ["wsgi.url_scheme"], - headers.pop("Host"), - environ["SCRIPT_NAME"], + environ["wsgi.url_scheme"], headers.pop("Host"), environ["SCRIPT_NAME"] ), "query_string": environ["QUERY_STRING"], "method": environ["REQUEST_METHOD"], @@ -430,6 +458,7 @@ class EnvironBuilder(object): self.files.add_file(key, *value) elif isinstance(value, dict): from warnings import warn + warn( "Passing a dict as file data is deprecated as of" " version 0.5 and will be removed in version 1.0. Use" @@ -438,9 +467,9 @@ class EnvironBuilder(object): stacklevel=2, ) value = dict(value) - mimetype = value.pop('mimetype', None) + mimetype = value.pop("mimetype", None) if mimetype is not None: - value['content_type'] = mimetype + value["content_type"] = mimetype self.files.add_file(key, **value) else: self.files.add_file(key, value) @@ -459,90 +488,100 @@ class EnvironBuilder(object): @base_url.setter def base_url(self, value): if value is None: - scheme = 'http' - netloc = 'localhost' - script_root = '' + scheme = "http" + netloc = "localhost" + script_root = "" else: scheme, netloc, script_root, qs, anchor = url_parse(value) if qs or anchor: - raise ValueError('base url must not contain a query string ' - 'or fragment') - self.script_root = script_root.rstrip('/') + raise ValueError("base url must not contain a query string or fragment") + self.script_root = script_root.rstrip("/") self.host = netloc self.url_scheme = scheme def _get_content_type(self): - ct = self.headers.get('Content-Type') + ct = self.headers.get("Content-Type") if ct is None and not self._input_stream: if self._files: - return 'multipart/form-data' + return "multipart/form-data" elif self._form: - return 'application/x-www-form-urlencoded' + return "application/x-www-form-urlencoded" return None return ct def _set_content_type(self, value): if value is None: - self.headers.pop('Content-Type', None) + self.headers.pop("Content-Type", None) else: - self.headers['Content-Type'] = value - - content_type = property(_get_content_type, _set_content_type, doc=''' - The content type for the request. Reflected from and to the - :attr:`headers`. Do not set if you set :attr:`files` or - :attr:`form` for auto detection.''') + self.headers["Content-Type"] = value + + content_type = property( + _get_content_type, + _set_content_type, + doc="""The content type for the request. Reflected from and to + the :attr:`headers`. Do not set if you set :attr:`files` or + :attr:`form` for auto detection.""", + ) del _get_content_type, _set_content_type def _get_content_length(self): - return self.headers.get('Content-Length', type=int) + return self.headers.get("Content-Length", type=int) def _get_mimetype(self): ct = self.content_type if ct: - return ct.split(';')[0].strip() + return ct.split(";")[0].strip() def _set_mimetype(self, value): self.content_type = get_content_type(value, self.charset) def _get_mimetype_params(self): def on_update(d): - self.headers['Content-Type'] = \ - dump_options_header(self.mimetype, d) - d = parse_options_header(self.headers.get('content-type', ''))[1] + self.headers["Content-Type"] = dump_options_header(self.mimetype, d) + + d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) - mimetype = property(_get_mimetype, _set_mimetype, doc=''' - The mimetype (content type without charset etc.) + mimetype = property( + _get_mimetype, + _set_mimetype, + doc="""The mimetype (content type without charset etc.) .. versionadded:: 0.14 - ''') - mimetype_params = property(_get_mimetype_params, doc=''' - The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be + """, + ) + mimetype_params = property( + _get_mimetype_params, + doc=""" The mimetype parameters as dict. For example if the + content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. .. versionadded:: 0.14 - ''') + """, + ) del _get_mimetype, _set_mimetype, _get_mimetype_params def _set_content_length(self, value): if value is None: - self.headers.pop('Content-Length', None) + self.headers.pop("Content-Length", None) else: - self.headers['Content-Length'] = str(value) + self.headers["Content-Length"] = str(value) - content_length = property(_get_content_length, _set_content_length, doc=''' - The content length as integer. Reflected from and to the + content_length = property( + _get_content_length, + _set_content_length, + doc="""The content length as integer. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or - :attr:`form` for auto detection.''') + :attr:`form` for auto detection.""", + ) del _get_content_length, _set_content_length - def form_property(name, storage, doc): - key = '_' + name + def form_property(name, storage, doc): # noqa: B902 + key = "_" + name def getter(self): if self._input_stream is not None: - raise AttributeError('an input stream is defined') + raise AttributeError("an input stream is defined") rv = getattr(self, key) if rv is None: rv = storage() @@ -553,14 +592,17 @@ class EnvironBuilder(object): def setter(self, value): self._input_stream = None setattr(self, key, value) + return property(getter, setter, doc=doc) - form = form_property('form', MultiDict, doc=''' - A :class:`MultiDict` of form values.''') - files = form_property('files', FileMultiDict, doc=''' - A :class:`FileMultiDict` of uploaded files. You can use the - :meth:`~FileMultiDict.add_file` method to add new files to the - dict.''') + form = form_property("form", MultiDict, doc="A :class:`MultiDict` of form values.") + files = form_property( + "files", + FileMultiDict, + doc="""A :class:`FileMultiDict` of uploaded files. You can use + the :meth:`~FileMultiDict.add_file` method to add new files to + the dict.""", + ) del form_property def _get_input_stream(self): @@ -570,30 +612,36 @@ class EnvironBuilder(object): self._input_stream = value self._form = self._files = None - input_stream = property(_get_input_stream, _set_input_stream, doc=''' - An optional input stream. If you set this it will clear - :attr:`form` and :attr:`files`.''') + input_stream = property( + _get_input_stream, + _set_input_stream, + doc="""An optional input stream. If you set this it will clear + :attr:`form` and :attr:`files`.""", + ) del _get_input_stream, _set_input_stream def _get_query_string(self): if self._query_string is None: if self._args is not None: return url_encode(self._args, charset=self.charset) - return '' + return "" return self._query_string def _set_query_string(self, value): self._query_string = value self._args = None - query_string = property(_get_query_string, _set_query_string, doc=''' - The query string. If you set this to a string :attr:`args` will - no longer be available.''') + query_string = property( + _get_query_string, + _set_query_string, + doc="""The query string. If you set this to a string + :attr:`args` will no longer be available.""", + ) del _get_query_string, _set_query_string def _get_args(self): if self._query_string is not None: - raise AttributeError('a query string is defined') + raise AttributeError("a query string is defined") if self._args is None: self._args = MultiDict() return self._args @@ -602,22 +650,23 @@ class EnvironBuilder(object): self._query_string = None self._args = value - args = property(_get_args, _set_args, doc=''' - The URL arguments as :class:`MultiDict`.''') + args = property( + _get_args, _set_args, doc="The URL arguments as :class:`MultiDict`." + ) del _get_args, _set_args @property def server_name(self): """The server name (read-only, use :attr:`host` to set)""" - return self.host.split(':', 1)[0] + return self.host.split(":", 1)[0] @property def server_port(self): """The server port as integer (read-only, use :attr:`host` to set)""" - pieces = self.host.split(':', 1) + pieces = self.host.split(":", 1) if len(pieces) == 2 and pieces[1].isdigit(): return int(pieces[1]) - elif self.url_scheme == 'https': + elif self.url_scheme == "https": return 443 return 80 @@ -665,15 +714,16 @@ class EnvironBuilder(object): end_pos = input_stream.tell() input_stream.seek(start_pos) content_length = end_pos - start_pos - elif mimetype == 'multipart/form-data': + elif mimetype == "multipart/form-data": values = CombinedMultiDict([self.form, self.files]) - input_stream, content_length, boundary = \ - stream_encode_multipart(values, charset=self.charset) + input_stream, content_length, boundary = stream_encode_multipart( + values, charset=self.charset + ) content_type = mimetype + '; boundary="%s"' % boundary - elif mimetype == 'application/x-www-form-urlencoded': + elif mimetype == "application/x-www-form-urlencoded": # XXX: py2v3 review values = url_encode(self.form, charset=self.charset) - values = values.encode('ascii') + values = values.encode("ascii") content_length = len(values) input_stream = BytesIO(values) else: @@ -688,40 +738,42 @@ class EnvironBuilder(object): qs = wsgi_encoding_dance(self.query_string) - result.update({ - 'REQUEST_METHOD': self.method, - 'SCRIPT_NAME': _path_encode(self.script_root), - 'PATH_INFO': _path_encode(self.path), - 'QUERY_STRING': qs, - # Non-standard, added by mod_wsgi, uWSGI - "REQUEST_URI": wsgi_encoding_dance(self.path), - # Non-standard, added by gunicorn - "RAW_URI": wsgi_encoding_dance(self.path), - 'SERVER_NAME': self.server_name, - 'SERVER_PORT': str(self.server_port), - 'HTTP_HOST': self.host, - 'SERVER_PROTOCOL': self.server_protocol, - 'wsgi.version': self.wsgi_version, - 'wsgi.url_scheme': self.url_scheme, - 'wsgi.input': input_stream, - 'wsgi.errors': self.errors_stream, - 'wsgi.multithread': self.multithread, - 'wsgi.multiprocess': self.multiprocess, - 'wsgi.run_once': self.run_once - }) + result.update( + { + "REQUEST_METHOD": self.method, + "SCRIPT_NAME": _path_encode(self.script_root), + "PATH_INFO": _path_encode(self.path), + "QUERY_STRING": qs, + # Non-standard, added by mod_wsgi, uWSGI + "REQUEST_URI": wsgi_encoding_dance(self.path), + # Non-standard, added by gunicorn + "RAW_URI": wsgi_encoding_dance(self.path), + "SERVER_NAME": self.server_name, + "SERVER_PORT": str(self.server_port), + "HTTP_HOST": self.host, + "SERVER_PROTOCOL": self.server_protocol, + "wsgi.version": self.wsgi_version, + "wsgi.url_scheme": self.url_scheme, + "wsgi.input": input_stream, + "wsgi.errors": self.errors_stream, + "wsgi.multithread": self.multithread, + "wsgi.multiprocess": self.multiprocess, + "wsgi.run_once": self.run_once, + } + ) headers = self.headers.copy() if content_type is not None: - result['CONTENT_TYPE'] = content_type + result["CONTENT_TYPE"] = content_type headers.set("Content-Type", content_type) if content_length is not None: - result['CONTENT_LENGTH'] = str(content_length) + result["CONTENT_LENGTH"] = str(content_length) headers.set("Content-Length", content_length) for key, value in headers.to_wsgi_list(): - result['HTTP_%s' % key.upper().replace('-', '_')] = value + result["HTTP_%s" % key.upper().replace("-", "_")] = value if self.environ_overrides: result.update(self.environ_overrides) @@ -740,15 +792,12 @@ class EnvironBuilder(object): class ClientRedirectError(Exception): - - """ - If a redirect loop is detected when using follow_redirects=True with + """If a redirect loop is detected when using follow_redirects=True with the :cls:`Client`, then this exception is raised. """ class Client(object): - """This class allows to send requests to a wrapped application. The response wrapper can be a class or factory function that takes @@ -781,8 +830,13 @@ class Client(object): The ``json`` parameter. """ - def __init__(self, application, response_wrapper=None, use_cookies=True, - allow_subdomain_redirects=False): + def __init__( + self, + application, + response_wrapper=None, + use_cookies=True, + allow_subdomain_redirects=False, + ): self.application = application self.response_wrapper = response_wrapper if use_cookies: @@ -791,24 +845,36 @@ class Client(object): self.cookie_jar = None self.allow_subdomain_redirects = allow_subdomain_redirects - def set_cookie(self, server_name, key, value='', max_age=None, - expires=None, path='/', domain=None, secure=None, - httponly=False, charset='utf-8'): + def set_cookie( + self, + server_name, + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=None, + httponly=False, + charset="utf-8", + ): """Sets a cookie in the client's cookie jar. The server name is required and has to match the one that is also passed to the open call. """ - assert self.cookie_jar is not None, 'cookies disabled' - header = dump_cookie(key, value, max_age, expires, path, domain, - secure, httponly, charset) - environ = create_environ(path, base_url='http://' + server_name) - headers = [('Set-Cookie', header)] + assert self.cookie_jar is not None, "cookies disabled" + header = dump_cookie( + key, value, max_age, expires, path, domain, secure, httponly, charset + ) + environ = create_environ(path, base_url="http://" + server_name) + headers = [("Set-Cookie", header)] self.cookie_jar.extract_wsgi(environ, headers) - def delete_cookie(self, server_name, key, path='/', domain=None): + def delete_cookie(self, server_name, key, path="/", domain=None): """Deletes a cookie in the test client.""" - self.set_cookie(server_name, key, expires=0, max_age=0, - path=path, domain=domain) + self.set_cookie( + server_name, key, expires=0, max_age=0, path=path, domain=domain + ) def run_wsgi_app(self, environ, buffered=False): """Runs the wrapped WSGI app with the given environment.""" @@ -826,7 +892,7 @@ class Client(object): scheme, netloc, path, qs, anchor = url_parse(new_location) builder = EnvironBuilder.from_environ(environ, query_string=qs) - to_name_parts = netloc.split(':', 1)[0].split(".") + to_name_parts = netloc.split(":", 1)[0].split(".") from_name_parts = builder.server_name.split(".") if to_name_parts != [""]: @@ -840,7 +906,7 @@ class Client(object): # Explain why a redirect to a different server name won't be followed. if to_name_parts != from_name_parts: - if to_name_parts[-len(from_name_parts):] == from_name_parts: + if to_name_parts[-len(from_name_parts) :] == from_name_parts: if not self.allow_subdomain_redirects: raise RuntimeError("Following subdomain redirects is not enabled.") else: @@ -849,9 +915,9 @@ class Client(object): path_parts = path.split("/") root_parts = builder.script_root.split("/") - if path_parts[:len(root_parts)] == root_parts: + if path_parts[: len(root_parts)] == root_parts: # Strip the script root from the path. - builder.path = path[len(builder.script_root):] + builder.path = path[len(builder.script_root) :] else: # The new location is not under the script root, so use the # whole path and clear the previous root. @@ -907,9 +973,9 @@ class Client(object): :param follow_redirects: Set this to True if the `Client` should follow HTTP redirects. """ - as_tuple = kwargs.pop('as_tuple', False) - buffered = kwargs.pop('buffered', False) - follow_redirects = kwargs.pop('follow_redirects', False) + as_tuple = kwargs.pop("as_tuple", False) + buffered = kwargs.pop("buffered", False) + follow_redirects = kwargs.pop("follow_redirects", False) environ = None if not kwargs and len(args) == 1: if isinstance(args[0], EnvironBuilder): @@ -941,15 +1007,13 @@ class Client(object): for _ in response[0]: pass - new_location = response[2]['location'] + new_location = response[2]["location"] new_redirect_entry = (new_location, status_code) if new_redirect_entry in redirect_chain: - raise ClientRedirectError('loop detected') + raise ClientRedirectError("loop detected") redirect_chain.append(new_redirect_entry) environ, response = self.resolve_redirect( - response, - new_location, - environ, buffered=buffered + response, new_location, environ, buffered=buffered ) if self.response_wrapper is not None: @@ -960,49 +1024,46 @@ class Client(object): def get(self, *args, **kw): """Like open but method is enforced to GET.""" - kw['method'] = 'GET' + kw["method"] = "GET" return self.open(*args, **kw) def patch(self, *args, **kw): """Like open but method is enforced to PATCH.""" - kw['method'] = 'PATCH' + kw["method"] = "PATCH" return self.open(*args, **kw) def post(self, *args, **kw): """Like open but method is enforced to POST.""" - kw['method'] = 'POST' + kw["method"] = "POST" return self.open(*args, **kw) def head(self, *args, **kw): """Like open but method is enforced to HEAD.""" - kw['method'] = 'HEAD' + kw["method"] = "HEAD" return self.open(*args, **kw) def put(self, *args, **kw): """Like open but method is enforced to PUT.""" - kw['method'] = 'PUT' + kw["method"] = "PUT" return self.open(*args, **kw) def delete(self, *args, **kw): """Like open but method is enforced to DELETE.""" - kw['method'] = 'DELETE' + kw["method"] = "DELETE" return self.open(*args, **kw) def options(self, *args, **kw): """Like open but method is enforced to OPTIONS.""" - kw['method'] = 'OPTIONS' + kw["method"] = "OPTIONS" return self.open(*args, **kw) def trace(self, *args, **kw): """Like open but method is enforced to TRACE.""" - kw['method'] = 'TRACE' + kw["method"] = "TRACE" return self.open(*args, **kw) def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.application - ) + return "<%s %r>" % (self.__class__.__name__, self.application) def create_environ(*args, **kwargs): @@ -1055,7 +1116,7 @@ def run_wsgi_app(app, environ, buffered=False): return buffer.append app_rv = app(environ, start_response) - close_func = getattr(app_rv, 'close', None) + close_func = getattr(app_rv, "close", None) app_iter = iter(app_rv) # when buffering we emit the close call early and convert the diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 972c6ca1..8ea23bee 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -9,15 +9,19 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import base64 import os import sys -import werkzeug from textwrap import wrap -from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response -from werkzeug.utils import escape -import base64 -logo = Response(base64.b64decode(''' +import werkzeug +from .utils import escape +from .wrappers import BaseRequest as Request +from .wrappers import BaseResponse as Response + +logo = Response( + base64.b64decode( + """ R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP///////// //////////////////////yH5BAEKAAgALAAAAACgAKAAAAT+EMlJq704680R+F0ojmRpnuj0rWnrv nB8rbRs33gu0bzu/0AObxgsGn3D5HHJbCUFyqZ0ukkSDlAidctNFg7gbI9LZlrBaHGtzAae0eloe25 @@ -51,10 +55,13 @@ UpyGlhjBUljyjHhWpf8OFaXwhp9O4T1gU9UeyPPa8A2l0p1kNqPXEVRm1AOs1oAGZU596t6SOR2mcB Oco1srWtkaVrMUzIErrKri85keKqRQYX9VX0/eAUK1hrSu6HMEX3Qh2sCh0q0D2CtnUqS4hj62sE/z aDs2Sg7MBS6xnQeooc2R2tC9YrKpEi9pLXfYXp20tDCpSP8rKlrD4axprb9u1Df5hSbz9QU0cRpfgn kiIzwKucd0wsEHlLpe5yHXuc6FrNelOl7pY2+11kTWx7VpRu97dXA3DO1vbkhcb4zyvERYajQgAADs -='''), mimetype='image/png') +=""" + ), + mimetype="image/png", +) -TEMPLATE = u'''\ +TEMPLATE = u"""\ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/TR/html4/loose.dtd"> <title>WSGI Information</title> @@ -130,24 +137,27 @@ TEMPLATE = u'''\ environments. <ul class="path">%(sys_path)s</ul> </div> -''' +""" def iter_sys_path(): - if os.name == 'posix': + if os.name == "posix": + def strip(x): - prefix = os.path.expanduser('~') + prefix = os.path.expanduser("~") if x.startswith(prefix): - x = '~' + x[len(prefix):] + x = "~" + x[len(prefix) :] return x + else: - strip = lambda x: x + + def strip(x): + return x cwd = os.path.abspath(os.getcwd()) for item in sys.path: path = os.path.join(cwd, item or os.path.curdir) - yield strip(os.path.normpath(path)), \ - not os.path.isdir(path), path != item + yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item def render_testapp(req): @@ -156,51 +166,51 @@ def render_testapp(req): except ImportError: eggs = () else: - eggs = sorted(pkg_resources.working_set, - key=lambda x: x.project_name.lower()) + eggs = sorted(pkg_resources.working_set, key=lambda x: x.project_name.lower()) python_eggs = [] for egg in eggs: try: version = egg.version except (ValueError, AttributeError): - version = 'unknown' - python_eggs.append('<li>%s <small>[%s]</small>' % ( - escape(egg.project_name), - escape(version) - )) + version = "unknown" + python_eggs.append( + "<li>%s <small>[%s]</small>" % (escape(egg.project_name), escape(version)) + ) wsgi_env = [] - sorted_environ = sorted(req.environ.items(), - key=lambda x: repr(x[0]).lower()) + sorted_environ = sorted(req.environ.items(), key=lambda x: repr(x[0]).lower()) for key, value in sorted_environ: - wsgi_env.append('<tr><th>%s<td><code>%s</code>' % ( - escape(str(key)), - ' '.join(wrap(escape(repr(value)))) - )) + wsgi_env.append( + "<tr><th>%s<td><code>%s</code>" + % (escape(str(key)), " ".join(wrap(escape(repr(value))))) + ) sys_path = [] for item, virtual, expanded in iter_sys_path(): class_ = [] if virtual: - class_.append('virtual') + class_.append("virtual") if expanded: - class_.append('exp') - sys_path.append('<li%s>%s' % ( - ' class="%s"' % ' '.join(class_) if class_ else '', - escape(item) - )) - - return (TEMPLATE % { - 'python_version': '<br>'.join(escape(sys.version).splitlines()), - 'platform': escape(sys.platform), - 'os': escape(os.name), - 'api_version': sys.api_version, - 'byteorder': sys.byteorder, - 'werkzeug_version': werkzeug.__version__, - 'python_eggs': '\n'.join(python_eggs), - 'wsgi_env': '\n'.join(wsgi_env), - 'sys_path': '\n'.join(sys_path) - }).encode('utf-8') + class_.append("exp") + sys_path.append( + "<li%s>%s" + % (' class="%s"' % " ".join(class_) if class_ else "", escape(item)) + ) + + return ( + TEMPLATE + % { + "python_version": "<br>".join(escape(sys.version).splitlines()), + "platform": escape(sys.platform), + "os": escape(os.name), + "api_version": sys.api_version, + "byteorder": sys.byteorder, + "werkzeug_version": werkzeug.__version__, + "python_eggs": "\n".join(python_eggs), + "wsgi_env": "\n".join(wsgi_env), + "sys_path": "\n".join(sys_path), + } + ).encode("utf-8") def test_app(environ, start_response): @@ -218,13 +228,14 @@ def test_app(environ, start_response): the Python interpreter and the installed libraries. """ req = Request(environ, populate_request=False) - if req.args.get('resource') == 'logo': + if req.args.get("resource") == "logo": response = logo else: - response = Response(render_testapp(req), mimetype='text/html') + response = Response(render_testapp(req), mimetype="text/html") return response(environ, start_response) -if __name__ == '__main__': - from werkzeug.serving import run_simple - run_simple('localhost', 5000, test_app, use_reloader=True) +if __name__ == "__main__": + from .serving import run_simple + + run_simple("localhost", 5000, test_app, use_reloader=True) diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index 09ec5b0c..38e9e5ad 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -15,56 +15,53 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from collections import namedtuple - import codecs import os import re +from collections import namedtuple -from werkzeug._compat import ( - PY2, - fix_tuple_repr, - implements_to_string, - make_literal_wrapper, - normalize_string_tuple, - text_type, - to_native, - to_unicode, - try_coerce_native, -) -from werkzeug._internal import _decode_idna, _encode_idna -from werkzeug.datastructures import MultiDict, iter_multi_items +from ._compat import fix_tuple_repr +from ._compat import implements_to_string +from ._compat import make_literal_wrapper +from ._compat import normalize_string_tuple +from ._compat import PY2 +from ._compat import text_type +from ._compat import to_native +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._internal import _decode_idna +from ._internal import _encode_idna +from .datastructures import iter_multi_items +from .datastructures import MultiDict # A regular expression for what a valid schema looks like -_scheme_re = re.compile(r'^[a-zA-Z0-9+-.]+$') +_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") # Characters that are safe in any part of an URL. -_always_safe = frozenset(bytearray( - b"abcdefghijklmnopqrstuvwxyz" - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"0123456789" - b"-._~" -)) - -_hexdigits = '0123456789ABCDEFabcdef' +_always_safe = frozenset( + bytearray( + b"abcdefghijklmnopqrstuvwxyz" + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" + b"0123456789" + b"-._~" + ) +) + +_hexdigits = "0123456789ABCDEFabcdef" _hextobyte = dict( - ((a + b).encode(), int(a + b, 16)) - for a in _hexdigits for b in _hexdigits + ((a + b).encode(), int(a + b, 16)) for a in _hexdigits for b in _hexdigits ) -_bytetohex = [ - ('%%%02X' % char).encode('ascii') for char in range(256) -] +_bytetohex = [("%%%02X" % char).encode("ascii") for char in range(256)] -_URLTuple = fix_tuple_repr(namedtuple( - '_URLTuple', - ['scheme', 'netloc', 'path', 'query', 'fragment'] -)) +_URLTuple = fix_tuple_repr( + namedtuple("_URLTuple", ["scheme", "netloc", "path", "query", "fragment"]) +) class BaseURL(_URLTuple): + """Superclass of :py:class:`URL` and :py:class:`BytesURL`.""" - '''Superclass of :py:class:`URL` and :py:class:`BytesURL`.''' __slots__ = () def replace(self, **kwargs): @@ -92,8 +89,8 @@ class BaseURL(_URLTuple): try: rv = _encode_idna(rv) except UnicodeError: - rv = rv.encode('ascii', 'ignore') - return to_native(rv, 'ascii', 'ignore') + rv = rv.encode("ascii", "ignore") + return to_native(rv, "ascii", "ignore") @property def port(self): @@ -169,19 +166,24 @@ class BaseURL(_URLTuple): def decode_netloc(self): """Decodes the netloc part into a string.""" - rv = _decode_idna(self.host or '') + rv = _decode_idna(self.host or "") - if ':' in rv: - rv = '[%s]' % rv + if ":" in rv: + rv = "[%s]" % rv port = self.port if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - _url_unquote_legacy(self.raw_username or '', '/:%@'), - _url_unquote_legacy(self.raw_password or '', '/:%@'), - ])) + rv = "%s:%d" % (rv, port) + auth = ":".join( + filter( + None, + [ + _url_unquote_legacy(self.raw_username or "", "/:%@"), + _url_unquote_legacy(self.raw_password or "", "/:%@"), + ], + ) + ) if auth: - rv = '%s@%s' % (auth, rv) + rv = "%s@%s" % (auth, rv) return rv def to_uri_tuple(self): @@ -192,7 +194,7 @@ class BaseURL(_URLTuple): It's usually more interesting to directly call :meth:`iri_to_uri` which will return a string. """ - return url_parse(iri_to_uri(self).encode('ascii')) + return url_parse(iri_to_uri(self).encode("ascii")) def to_iri_tuple(self): """Returns a :class:`URL` tuple that holds a IRI. This will try @@ -223,42 +225,44 @@ class BaseURL(_URLTuple): supported. Defaults to ``None`` which is autodetect. """ - if self.scheme != 'file': + if self.scheme != "file": return None, None path = url_unquote(self.path) host = self.netloc or None if pathformat is None: - if os.name == 'nt': - pathformat = 'windows' + if os.name == "nt": + pathformat = "windows" else: - pathformat = 'posix' + pathformat = "posix" - if pathformat == 'windows': - if path[:1] == '/' and path[1:2].isalpha() and path[2:3] in '|:': - path = path[1:2] + ':' + path[3:] - windows_share = path[:3] in ('\\' * 3, '/' * 3) + if pathformat == "windows": + if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": + path = path[1:2] + ":" + path[3:] + windows_share = path[:3] in ("\\" * 3, "/" * 3) import ntpath + path = ntpath.normpath(path) # Windows shared drives are represented as ``\\host\\directory``. # That results in a URL like ``file://///host/directory``, and a # path like ``///host/directory``. We need to special-case this # because the path contains the hostname. if windows_share and host is None: - parts = path.lstrip('\\').split('\\', 1) + parts = path.lstrip("\\").split("\\", 1) if len(parts) == 2: host, path = parts else: host = parts[0] - path = '' - elif pathformat == 'posix': + path = "" + elif pathformat == "posix": import posixpath + path = posixpath.normpath(path) else: - raise TypeError('Invalid path format %s' % repr(pathformat)) + raise TypeError("Invalid path format %s" % repr(pathformat)) - if host in ('127.0.0.1', '::1', 'localhost'): + if host in ("127.0.0.1", "::1", "localhost"): host = None return host, path @@ -291,7 +295,7 @@ class BaseURL(_URLTuple): return rv, None host = rv[1:idx] - rest = rv[idx + 1:] + rest = rv[idx + 1 :] if rest.startswith(self._colon): return host, rest[1:] return host, None @@ -299,81 +303,84 @@ class BaseURL(_URLTuple): @implements_to_string class URL(BaseURL): - """Represents a parsed URL. This behaves like a regular tuple but also has some extra attributes that give further insight into the URL. """ + __slots__ = () - _at = '@' - _colon = ':' - _lbracket = '[' - _rbracket = ']' + _at = "@" + _colon = ":" + _lbracket = "[" + _rbracket = "]" def __str__(self): return self.to_url() def encode_netloc(self): """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or '' - if ':' in rv: - rv = '[%s]' % rv + rv = self.ascii_host or "" + if ":" in rv: + rv = "[%s]" % rv port = self.port if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - url_quote(self.raw_username or '', 'utf-8', 'strict', '/:%'), - url_quote(self.raw_password or '', 'utf-8', 'strict', '/:%'), - ])) + rv = "%s:%d" % (rv, port) + auth = ":".join( + filter( + None, + [ + url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), + url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), + ], + ) + ) if auth: - rv = '%s@%s' % (auth, rv) + rv = "%s@%s" % (auth, rv) return to_native(rv) - def encode(self, charset='utf-8', errors='replace'): + def encode(self, charset="utf-8", errors="replace"): """Encodes the URL to a tuple made out of bytes. The charset is only being used for the path, query and fragment. """ return BytesURL( - self.scheme.encode('ascii'), + self.scheme.encode("ascii"), self.encode_netloc(), self.path.encode(charset, errors), self.query.encode(charset, errors), - self.fragment.encode(charset, errors) + self.fragment.encode(charset, errors), ) class BytesURL(BaseURL): - """Represents a parsed URL in bytes.""" + __slots__ = () - _at = b'@' - _colon = b':' - _lbracket = b'[' - _rbracket = b']' + _at = b"@" + _colon = b":" + _lbracket = b"[" + _rbracket = b"]" def __str__(self): - return self.to_url().decode('utf-8', 'replace') + return self.to_url().decode("utf-8", "replace") def encode_netloc(self): """Returns the netloc unchanged as bytes.""" return self.netloc - def decode(self, charset='utf-8', errors='replace'): + def decode(self, charset="utf-8", errors="replace"): """Decodes the URL to a tuple made out of strings. The charset is only being used for the path, query and fragment. """ return URL( - self.scheme.decode('ascii'), + self.scheme.decode("ascii"), self.decode_netloc(), self.path.decode(charset, errors), self.query.decode(charset, errors), - self.fragment.decode(charset, errors) + self.fragment.decode(charset, errors), ) -_unquote_maps = { - frozenset(): _hextobyte, -} +_unquote_maps = {frozenset(): _hextobyte} def _unquote_to_bytes(string, unsafe=""): @@ -418,15 +425,14 @@ def _url_encode_impl(obj, charset, encode_keys, sort, key): key = text_type(key).encode(charset) if not isinstance(value, bytes): value = text_type(value).encode(charset) - yield _fast_url_quote_plus(key) + '=' + _fast_url_quote_plus(value) + yield _fast_url_quote_plus(key) + "=" + _fast_url_quote_plus(value) -def _url_unquote_legacy(value, unsafe=''): +def _url_unquote_legacy(value, unsafe=""): try: - return url_unquote(value, charset='utf-8', - errors='strict', unsafe=unsafe) + return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) except UnicodeError: - return url_unquote(value, charset='latin1', unsafe=unsafe) + return url_unquote(value, charset="latin1", unsafe=unsafe) def url_parse(url, scheme=None, allow_fragments=True): @@ -446,38 +452,39 @@ def url_parse(url, scheme=None, allow_fragments=True): is_text_based = isinstance(url, text_type) if scheme is None: - scheme = s('') - netloc = query = fragment = s('') - i = url.find(s(':')) - if i > 0 and _scheme_re.match(to_native(url[:i], errors='replace')): + scheme = s("") + netloc = query = fragment = s("") + i = url.find(s(":")) + if i > 0 and _scheme_re.match(to_native(url[:i], errors="replace")): # make sure "iri" is not actually a port number (in which case # "scheme" is really part of the path) - rest = url[i + 1:] - if not rest or any(c not in s('0123456789') for c in rest): + rest = url[i + 1 :] + if not rest or any(c not in s("0123456789") for c in rest): # not a port number scheme, url = url[:i].lower(), rest - if url[:2] == s('//'): + if url[:2] == s("//"): delim = len(url) - for c in s('/?#'): + for c in s("/?#"): wdelim = url.find(c, 2) if wdelim >= 0: delim = min(delim, wdelim) netloc, url = url[2:delim], url[delim:] - if (s('[') in netloc and s(']') not in netloc) or \ - (s(']') in netloc and s('[') not in netloc): - raise ValueError('Invalid IPv6 URL') + if (s("[") in netloc and s("]") not in netloc) or ( + s("]") in netloc and s("[") not in netloc + ): + raise ValueError("Invalid IPv6 URL") - if allow_fragments and s('#') in url: - url, fragment = url.split(s('#'), 1) - if s('?') in url: - url, query = url.split(s('?'), 1) + if allow_fragments and s("#") in url: + url, fragment = url.split(s("#"), 1) + if s("?") in url: + url, query = url.split(s("?"), 1) result_type = URL if is_text_based else BytesURL return result_type(scheme, netloc, url, query, fragment) -def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe=''): +def _make_fast_url_quote(charset="utf-8", errors="strict", safe="/:", unsafe=""): """Precompile the translation table for a URL encoding function. Unlike :func:`url_quote`, the generated function only takes the @@ -495,12 +502,15 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='') unsafe = unsafe.encode(charset, errors) safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - table = [chr(c) if c in safe else '%%%02X' % c for c in range(256)] + table = [chr(c) if c in safe else "%%%02X" % c for c in range(256)] if not PY2: + def quote(string): return "".join([table[c] for c in string]) + else: + def quote(string): return "".join([table[c] for c in bytearray(string)]) @@ -508,14 +518,14 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='') _fast_url_quote = _make_fast_url_quote() -_fast_quote_plus = _make_fast_url_quote(safe=' ', unsafe='+') +_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+") def _fast_url_quote_plus(string): - return _fast_quote_plus(string).replace(' ', '+') + return _fast_quote_plus(string).replace(" ", "+") -def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''): +def url_quote(string, charset="utf-8", errors="strict", safe="/:", unsafe=""): """URL encode a single string with a given encoding. :param s: the string to quote. @@ -544,7 +554,7 @@ def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''): return to_native(bytes(rv)) -def url_quote_plus(string, charset='utf-8', errors='strict', safe=''): +def url_quote_plus(string, charset="utf-8", errors="strict", safe=""): """URL encode a single string with the given encoding and convert whitespace to "+". @@ -552,7 +562,7 @@ def url_quote_plus(string, charset='utf-8', errors='strict', safe=''): :param charset: The charset to be used. :param safe: An optional sequence of safe characters. """ - return url_quote(string, charset, errors, safe + ' ', '+').replace(' ', '+') + return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") def url_unparse(components): @@ -562,31 +572,30 @@ def url_unparse(components): :param components: the parsed URL as tuple which should be converted into a URL string. """ - scheme, netloc, path, query, fragment = \ - normalize_string_tuple(components) + scheme, netloc, path, query, fragment = normalize_string_tuple(components) s = make_literal_wrapper(scheme) - url = s('') + url = s("") # We generally treat file:///x and file:/x the same which is also # what browsers seem to do. This also allows us to ignore a schema # register for netloc utilization or having to differenciate between # empty and missing netloc. - if netloc or (scheme and path.startswith(s('/'))): - if path and path[:1] != s('/'): - path = s('/') + path - url = s('//') + (netloc or s('')) + path + if netloc or (scheme and path.startswith(s("/"))): + if path and path[:1] != s("/"): + path = s("/") + path + url = s("//") + (netloc or s("")) + path elif path: url += path if scheme: - url = scheme + s(':') + url + url = scheme + s(":") + url if query: - url = url + s('?') + query + url = url + s("?") + query if fragment: - url = url + s('#') + fragment + url = url + s("#") + fragment return url -def url_unquote(string, charset='utf-8', errors='replace', unsafe=''): +def url_unquote(string, charset="utf-8", errors="replace", unsafe=""): """URL decode a single string with a given encoding. If the charset is set to `None` no unicode decoding is performed and raw bytes are returned. @@ -602,7 +611,7 @@ def url_unquote(string, charset='utf-8', errors='replace', unsafe=''): return rv -def url_unquote_plus(s, charset='utf-8', errors='replace'): +def url_unquote_plus(s, charset="utf-8", errors="replace"): """URL decode a single string with the given `charset` and decode "+" to whitespace. @@ -616,13 +625,13 @@ def url_unquote_plus(s, charset='utf-8', errors='replace'): :param errors: The error handling for the `charset` decoding. """ if isinstance(s, text_type): - s = s.replace(u'+', u' ') + s = s.replace(u"+", u" ") else: - s = s.replace(b'+', b' ') + s = s.replace(b"+", b" ") return url_unquote(s, charset, errors) -def url_fix(s, charset='utf-8'): +def url_fix(s, charset="utf-8"): r"""Sometimes you get an URL by a user that just isn't a real URL because it contains unsafe characters like ' ' and so on. This function can fix some of the problems in a similar way browsers handle data entered by the @@ -638,19 +647,18 @@ def url_fix(s, charset='utf-8'): # First step is to switch to unicode processing and to convert # backslashes (which are invalid in URLs anyways) to slashes. This is # consistent with what Chrome does. - s = to_unicode(s, charset, 'replace').replace('\\', '/') + s = to_unicode(s, charset, "replace").replace("\\", "/") # For the specific case that we look like a malformed windows URL # we want to fix this up manually: - if s.startswith('file://') and s[7:8].isalpha() and s[8:10] in (':/', '|/'): - s = 'file:///' + s[7:] + if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"): + s = "file:///" + s[7:] url = url_parse(s) - path = url_quote(url.path, charset, safe='/%+$!*\'(),') - qs = url_quote_plus(url.query, charset, safe=':&%=+$!*\'(),') - anchor = url_quote_plus(url.fragment, charset, safe=':&%=+$!*\'(),') - return to_native(url_unparse((url.scheme, url.encode_netloc(), - path, qs, anchor))) + path = url_quote(url.path, charset, safe="/%+$!*'(),") + qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),") + anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),") + return to_native(url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor))) # not-unreserved characters remain quoted when unquoting to IRI @@ -661,7 +669,7 @@ def _codec_error_url_quote(e): """Used in :func:`uri_to_iri` after unquoting to re-quote any invalid bytes. """ - out = _fast_url_quote(e.object[e.start:e.end]) + out = _fast_url_quote(e.object[e.start : e.end]) if PY2: out = out.decode("utf-8") @@ -752,7 +760,7 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False): # contains ASCII characters, return it unconverted. try: native_iri = to_native(iri) - ascii_iri = native_iri.encode('ascii') + ascii_iri = native_iri.encode("ascii") # Only return if it doesn't have whitespace. (Why?) if len(ascii_iri.split()) == 1: @@ -769,8 +777,15 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False): ) -def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True, - errors='replace', separator='&', cls=None): +def url_decode( + s, + charset="utf-8", + decode_keys=False, + include_empty=True, + errors="replace", + separator="&", + cls=None, +): """ Parse a querystring and return it as :class:`MultiDict`. There is a difference in key decoding on different Python versions. On Python 3 @@ -812,16 +827,27 @@ def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True, if cls is None: cls = MultiDict if isinstance(s, text_type) and not isinstance(separator, text_type): - separator = separator.decode(charset or 'ascii') + separator = separator.decode(charset or "ascii") elif isinstance(s, bytes) and not isinstance(separator, bytes): - separator = separator.encode(charset or 'ascii') - return cls(_url_decode_impl(s.split(separator), charset, decode_keys, - include_empty, errors)) + separator = separator.encode(charset or "ascii") + return cls( + _url_decode_impl( + s.split(separator), charset, decode_keys, include_empty, errors + ) + ) -def url_decode_stream(stream, charset='utf-8', decode_keys=False, - include_empty=True, errors='replace', separator='&', - cls=None, limit=None, return_iterator=False): +def url_decode_stream( + stream, + charset="utf-8", + decode_keys=False, + include_empty=True, + errors="replace", + separator="&", + cls=None, + limit=None, + return_iterator=False, +): """Works like :func:`url_decode` but decodes a stream. The behavior of stream and limit follows functions like :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is @@ -849,14 +875,18 @@ def url_decode_stream(stream, charset='utf-8', decode_keys=False, and an iterator over all decoded pairs is returned """ - from werkzeug.wsgi import make_chunk_iter + from .wsgi import make_chunk_iter + + pair_iter = make_chunk_iter(stream, separator, limit) + decoder = _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors) + if return_iterator: - cls = lambda x: x - elif cls is None: + return decoder + + if cls is None: cls = MultiDict - pair_iter = make_chunk_iter(stream, separator, limit) - return cls(_url_decode_impl(pair_iter, charset, decode_keys, - include_empty, errors)) + + return cls(decoder) def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors): @@ -864,22 +894,23 @@ def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors): if not pair: continue s = make_literal_wrapper(pair) - equal = s('=') + equal = s("=") if equal in pair: key, value = pair.split(equal, 1) else: if not include_empty: continue key = pair - value = s('') + value = s("") key = url_unquote_plus(key, charset, errors) if charset is not None and PY2 and not decode_keys: key = try_coerce_native(key) yield key, url_unquote_plus(value, charset, errors) -def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None, - separator=b'&'): +def url_encode( + obj, charset="utf-8", encode_keys=False, sort=False, key=None, separator=b"&" +): """URL encode a dict/`MultiDict`. If a value is `None` it will not appear in the result string. Per default only values are encoded into the target charset strings. If `encode_keys` is set to ``True`` unicode keys are @@ -900,12 +931,19 @@ def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None, :param key: an optional function to be used for sorting. For more details check out the :func:`sorted` documentation. """ - separator = to_native(separator, 'ascii') + separator = to_native(separator, "ascii") return separator.join(_url_encode_impl(obj, charset, encode_keys, sort, key)) -def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False, - sort=False, key=None, separator=b'&'): +def url_encode_stream( + obj, + stream=None, + charset="utf-8", + encode_keys=False, + sort=False, + key=None, + separator=b"&", +): """Like :meth:`url_encode` but writes the results to a stream object. If the stream is `None` a generator over all encoded pairs is returned. @@ -924,7 +962,7 @@ def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False, :param key: an optional function to be used for sorting. For more details check out the :func:`sorted` documentation. """ - separator = to_native(separator, 'ascii') + separator = to_native(separator, "ascii") gen = _url_encode_impl(obj, charset, encode_keys, sort, key) if stream is None: return gen @@ -955,55 +993,53 @@ def url_join(base, url, allow_fragments=True): if not url: return base - bscheme, bnetloc, bpath, bquery, bfragment = \ - url_parse(base, allow_fragments=allow_fragments) - scheme, netloc, path, query, fragment = \ - url_parse(url, bscheme, allow_fragments) + bscheme, bnetloc, bpath, bquery, bfragment = url_parse( + base, allow_fragments=allow_fragments + ) + scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments) if scheme != bscheme: return url if netloc: return url_unparse((scheme, netloc, path, query, fragment)) netloc = bnetloc - if path[:1] == s('/'): - segments = path.split(s('/')) + if path[:1] == s("/"): + segments = path.split(s("/")) elif not path: - segments = bpath.split(s('/')) + segments = bpath.split(s("/")) if not query: query = bquery else: - segments = bpath.split(s('/'))[:-1] + path.split(s('/')) + segments = bpath.split(s("/"))[:-1] + path.split(s("/")) # If the rightmost part is "./" we want to keep the slash but # remove the dot. - if segments[-1] == s('.'): - segments[-1] = s('') + if segments[-1] == s("."): + segments[-1] = s("") # Resolve ".." and "." - segments = [segment for segment in segments if segment != s('.')] + segments = [segment for segment in segments if segment != s(".")] while 1: i = 1 n = len(segments) - 1 while i < n: - if segments[i] == s('..') and \ - segments[i - 1] not in (s(''), s('..')): - del segments[i - 1:i + 1] + if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): + del segments[i - 1 : i + 1] break i += 1 else: break # Remove trailing ".." if the URL is absolute - unwanted_marker = [s(''), s('..')] + unwanted_marker = [s(""), s("..")] while segments[:2] == unwanted_marker: del segments[1] - path = s('/').join(segments) + path = s("/").join(segments) return url_unparse((scheme, netloc, path, query, fragment)) class Href(object): - """Implements a callable that constructs URLs with the given base. The function can be called with any number of positional and keyword arguments which than are used to assemble the URL. Works with URLs @@ -1054,39 +1090,45 @@ class Href(object): `sort` and `key` were added. """ - def __init__(self, base='./', charset='utf-8', sort=False, key=None): + def __init__(self, base="./", charset="utf-8", sort=False, key=None): if not base: - base = './' + base = "./" self.base = base self.charset = charset self.sort = sort self.key = key def __getattr__(self, name): - if name[:2] == '__': + if name[:2] == "__": raise AttributeError(name) base = self.base - if base[-1:] != '/': - base += '/' + if base[-1:] != "/": + base += "/" return Href(url_join(base, name), self.charset, self.sort, self.key) def __call__(self, *path, **query): if path and isinstance(path[-1], dict): if query: - raise TypeError('keyword arguments and query-dicts ' - 'can\'t be combined') + raise TypeError("keyword arguments and query-dicts can't be combined") query, path = path[-1], path[:-1] elif query: - query = dict([(k.endswith('_') and k[:-1] or k, v) - for k, v in query.items()]) - path = '/'.join([to_unicode(url_quote(x, self.charset), 'ascii') - for x in path if x is not None]).lstrip('/') + query = dict( + [(k.endswith("_") and k[:-1] or k, v) for k, v in query.items()] + ) + path = "/".join( + [ + to_unicode(url_quote(x, self.charset), "ascii") + for x in path + if x is not None + ] + ).lstrip("/") rv = self.base if path: - if not rv.endswith('/'): - rv += '/' - rv = url_join(rv, './' + path) + if not rv.endswith("/"): + rv += "/" + rv = url_join(rv, "./" + path) if query: - rv += '?' + to_unicode(url_encode(query, self.charset, sort=self.sort, - key=self.key), 'ascii') + rv += "?" + to_unicode( + url_encode(query, self.charset, sort=self.sort, key=self.key), "ascii" + ) return to_native(rv) diff --git a/src/werkzeug/useragents.py b/src/werkzeug/useragents.py index 90d2649b..e265e093 100644 --- a/src/werkzeug/useragents.py +++ b/src/werkzeug/useragents.py @@ -12,80 +12,82 @@ :license: BSD-3-Clause """ import re +import warnings class UserAgentParser(object): - """A simple user agent parser. Used by the `UserAgent`.""" platforms = ( - ('cros', 'chromeos'), - ('iphone|ios', 'iphone'), - ('ipad', 'ipad'), - (r'darwin|mac|os\s*x', 'macos'), - ('win', 'windows'), - (r'android', 'android'), - ('netbsd', 'netbsd'), - ('openbsd', 'openbsd'), - ('freebsd', 'freebsd'), - ('dragonfly', 'dragonflybsd'), - ('(sun|i86)os', 'solaris'), - (r'x11|lin(\b|ux)?', 'linux'), - (r'nintendo\s+wii', 'wii'), - ('irix', 'irix'), - ('hp-?ux', 'hpux'), - ('aix', 'aix'), - ('sco|unix_sv', 'sco'), - ('bsd', 'bsd'), - ('amiga', 'amiga'), - ('blackberry|playbook', 'blackberry'), - ('symbian', 'symbian') + ("cros", "chromeos"), + ("iphone|ios", "iphone"), + ("ipad", "ipad"), + (r"darwin|mac|os\s*x", "macos"), + ("win", "windows"), + (r"android", "android"), + ("netbsd", "netbsd"), + ("openbsd", "openbsd"), + ("freebsd", "freebsd"), + ("dragonfly", "dragonflybsd"), + ("(sun|i86)os", "solaris"), + (r"x11|lin(\b|ux)?", "linux"), + (r"nintendo\s+wii", "wii"), + ("irix", "irix"), + ("hp-?ux", "hpux"), + ("aix", "aix"), + ("sco|unix_sv", "sco"), + ("bsd", "bsd"), + ("amiga", "amiga"), + ("blackberry|playbook", "blackberry"), + ("symbian", "symbian"), ) browsers = ( - ('googlebot', 'google'), - ('msnbot', 'msn'), - ('yahoo', 'yahoo'), - ('ask jeeves', 'ask'), - (r'aol|america\s+online\s+browser', 'aol'), - ('opera', 'opera'), - ('edge', 'edge'), - ('chrome|crios', 'chrome'), - ('seamonkey', 'seamonkey'), - ('firefox|firebird|phoenix|iceweasel', 'firefox'), - ('galeon', 'galeon'), - ('safari|version', 'safari'), - ('webkit', 'webkit'), - ('camino', 'camino'), - ('konqueror', 'konqueror'), - ('k-meleon', 'kmeleon'), - ('netscape', 'netscape'), - (r'msie|microsoft\s+internet\s+explorer|trident/.+? rv:', 'msie'), - ('lynx', 'lynx'), - ('links', 'links'), - ('Baiduspider', 'baidu'), - ('bingbot', 'bing'), - ('mozilla', 'mozilla') + ("googlebot", "google"), + ("msnbot", "msn"), + ("yahoo", "yahoo"), + ("ask jeeves", "ask"), + (r"aol|america\s+online\s+browser", "aol"), + ("opera", "opera"), + ("edge", "edge"), + ("chrome|crios", "chrome"), + ("seamonkey", "seamonkey"), + ("firefox|firebird|phoenix|iceweasel", "firefox"), + ("galeon", "galeon"), + ("safari|version", "safari"), + ("webkit", "webkit"), + ("camino", "camino"), + ("konqueror", "konqueror"), + ("k-meleon", "kmeleon"), + ("netscape", "netscape"), + (r"msie|microsoft\s+internet\s+explorer|trident/.+? rv:", "msie"), + ("lynx", "lynx"), + ("links", "links"), + ("Baiduspider", "baidu"), + ("bingbot", "bing"), + ("mozilla", "mozilla"), ) - _browser_version_re = r'(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?' + _browser_version_re = r"(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?" _language_re = re.compile( - r'(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|' - r'(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)' + r"(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|" + r"(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)" ) def __init__(self): self.platforms = [(b, re.compile(a, re.I)) for a, b in self.platforms] - self.browsers = [(b, re.compile(self._browser_version_re % a, re.I)) - for a, b in self.browsers] + self.browsers = [ + (b, re.compile(self._browser_version_re % a, re.I)) + for a, b in self.browsers + ] def __call__(self, user_agent): - for platform, regex in self.platforms: + for platform, regex in self.platforms: # noqa: B007 match = regex.search(user_agent) if match is not None: break else: platform = None - for browser, regex in self.browsers: + for browser, regex in self.browsers: # noqa: B007 match = regex.search(user_agent) if match is not None: version = match.group(1) @@ -101,7 +103,6 @@ class UserAgentParser(object): class UserAgent(object): - """Represents a user agent. Pass it a WSGI environment or a user agent string and you can inspect some of the details from the user agent string via the attributes. The following attributes exist: @@ -181,10 +182,11 @@ class UserAgent(object): def __init__(self, environ_or_string): if isinstance(environ_or_string, dict): - environ_or_string = environ_or_string.get('HTTP_USER_AGENT', '') + environ_or_string = environ_or_string.get("HTTP_USER_AGENT", "") self.string = environ_or_string - self.platform, self.browser, self.version, self.language = \ - self._parser(environ_or_string) + self.platform, self.browser, self.version, self.language = self._parser( + environ_or_string + ) def to_header(self): return self.string @@ -198,16 +200,21 @@ class UserAgent(object): __bool__ = __nonzero__ def __repr__(self): - return '<%s %r/%s>' % ( - self.__class__.__name__, - self.browser, - self.version - ) + return "<%s %r/%s>" % (self.__class__.__name__, self.browser, self.version) -# conceptionally this belongs in this module but because we want to lazily -# load the user agent module (which happens in wrappers.py) we have to import -# it afterwards. The class itself has the module set to this module so -# pickle, inspect and similar modules treat the object as if it was really -# implemented here. -from werkzeug.wrappers import UserAgentMixin # noqa +# DEPRECATED +from .wrappers import UserAgentMixin as _UserAgentMixin + + +class UserAgentMixin(_UserAgentMixin): + @property + def user_agent(self, *args, **kwargs): + warnings.warn( + "'werkzeug.useragents.UserAgentMixin' should be imported" + " from 'werkzeug.wrappers.UserAgentMixin'. This old import" + " will be removed in version 1.0.", + DeprecationWarning, + stacklevel=2, + ) + return super(_UserAgentMixin, self).user_agent diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index 406f62b3..20620572 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -11,32 +11,47 @@ :license: BSD-3-Clause """ import codecs -import re import os -import sys import pkgutil +import re +import sys import warnings +from ._compat import iteritems +from ._compat import PY2 +from ._compat import reraise +from ._compat import string_types +from ._compat import text_type +from ._compat import unichr +from ._internal import _DictAccessorProperty +from ._internal import _missing +from ._internal import _parse_signature + try: from html.entities import name2codepoint except ImportError: from htmlentitydefs import name2codepoint -from werkzeug._compat import unichr, text_type, string_types, iteritems, \ - reraise, PY2 -from werkzeug._internal import _DictAccessorProperty, \ - _parse_signature, _missing - -_format_re = re.compile(r'\$(?:(%s)|\{(%s)\})' % (('[a-zA-Z_][a-zA-Z0-9_]*',) * 2)) -_entity_re = re.compile(r'&([^;]+);') -_filename_ascii_strip_re = re.compile(r'[^A-Za-z0-9_.-]') -_windows_device_files = ('CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1', - 'LPT2', 'LPT3', 'PRN', 'NUL') +_format_re = re.compile(r"\$(?:(%s)|\{(%s)\})" % (("[a-zA-Z_][a-zA-Z0-9_]*",) * 2)) +_entity_re = re.compile(r"&([^;]+);") +_filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") +_windows_device_files = ( + "CON", + "AUX", + "COM1", + "COM2", + "COM3", + "COM4", + "LPT1", + "LPT2", + "LPT3", + "PRN", + "NUL", +) class cached_property(property): - """A decorator that converts a function into a lazy property. The function wrapped is called the first time to retrieve the result and then that calculated result is used the next time you access @@ -79,7 +94,6 @@ class cached_property(property): class environ_property(_DictAccessorProperty): - """Maps request attributes to environment variables. This works not only for the Werzeug request object, but also any other class with an environ attribute: @@ -107,7 +121,6 @@ class environ_property(_DictAccessorProperty): class header_property(_DictAccessorProperty): - """Like `environ_property` but for headers.""" def lookup(self, obj): @@ -115,7 +128,6 @@ class header_property(_DictAccessorProperty): class HTMLBuilder(object): - """Helper object for HTML generation. Per default there are two instances of that class. The `html` one, and @@ -141,20 +153,45 @@ class HTMLBuilder(object): u'<p><foo></p>' """ - _entity_re = re.compile(r'&([^;]+);') + _entity_re = re.compile(r"&([^;]+);") _entities = name2codepoint.copy() - _entities['apos'] = 39 - _empty_elements = set([ - 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame', - 'hr', 'img', 'input', 'keygen', 'isindex', 'link', 'meta', 'param', - 'source', 'wbr' - ]) - _boolean_attributes = set([ - 'selected', 'checked', 'compact', 'declare', 'defer', 'disabled', - 'ismap', 'multiple', 'nohref', 'noresize', 'noshade', 'nowrap' - ]) - _plaintext_elements = set(['textarea']) - _c_like_cdata = set(['script', 'style']) + _entities["apos"] = 39 + _empty_elements = { + "area", + "base", + "basefont", + "br", + "col", + "command", + "embed", + "frame", + "hr", + "img", + "input", + "keygen", + "isindex", + "link", + "meta", + "param", + "source", + "wbr", + } + _boolean_attributes = { + "selected", + "checked", + "compact", + "declare", + "defer", + "disabled", + "ismap", + "multiple", + "nohref", + "noresize", + "noshade", + "nowrap", + } + _plaintext_elements = {"textarea"} + _c_like_cdata = {"script", "style"} def __init__(self, dialect): self._dialect = dialect @@ -163,56 +200,56 @@ class HTMLBuilder(object): return escape(s) def __getattr__(self, tag): - if tag[:2] == '__': + if tag[:2] == "__": raise AttributeError(tag) def proxy(*children, **arguments): - buffer = '<' + tag + buffer = "<" + tag for key, value in iteritems(arguments): if value is None: continue - if key[-1] == '_': + if key[-1] == "_": key = key[:-1] if key in self._boolean_attributes: if not value: continue - if self._dialect == 'xhtml': + if self._dialect == "xhtml": value = '="' + key + '"' else: - value = '' + value = "" else: value = '="' + escape(value) + '"' - buffer += ' ' + key + value + buffer += " " + key + value if not children and tag in self._empty_elements: - if self._dialect == 'xhtml': - buffer += ' />' + if self._dialect == "xhtml": + buffer += " />" else: - buffer += '>' + buffer += ">" return buffer - buffer += '>' + buffer += ">" - children_as_string = ''.join([text_type(x) for x in children - if x is not None]) + children_as_string = "".join( + [text_type(x) for x in children if x is not None] + ) if children_as_string: if tag in self._plaintext_elements: children_as_string = escape(children_as_string) - elif tag in self._c_like_cdata and self._dialect == 'xhtml': - children_as_string = '/*<![CDATA[*/' + \ - children_as_string + '/*]]>*/' - buffer += children_as_string + '</' + tag + '>' + elif tag in self._c_like_cdata and self._dialect == "xhtml": + children_as_string = ( + "/*<![CDATA[*/" + children_as_string + "/*]]>*/" + ) + buffer += children_as_string + "</" + tag + ">" return buffer + return proxy def __repr__(self): - return '<%s for %r>' % ( - self.__class__.__name__, - self._dialect - ) + return "<%s for %r>" % (self.__class__.__name__, self._dialect) -html = HTMLBuilder('html') -xhtml = HTMLBuilder('xhtml') +html = HTMLBuilder("html") +xhtml = HTMLBuilder("xhtml") # https://cgit.freedesktop.org/xdg/shared-mime-info/tree/freedesktop.org.xml.in # https://www.iana.org/assignments/media-types/media-types.xhtml @@ -243,11 +280,11 @@ def get_content_type(mimetype, charset): ``application/javascript`` are also given charsets. """ if ( - mimetype.startswith('text/') + mimetype.startswith("text/") or mimetype in _charset_mimetypes - or mimetype.endswith('+xml') + or mimetype.endswith("+xml") ): - mimetype += '; charset=' + charset + mimetype += "; charset=" + charset return mimetype @@ -294,7 +331,7 @@ def detect_utf_encoding(data): return "utf-16-le" if len(head) == 2: - return "utf-16-be" if head.startswith(b'\x00') else "utf-16-le" + return "utf-16-be" if head.startswith(b"\x00") else "utf-16-le" return "utf-8" @@ -311,11 +348,13 @@ def format_string(string, context): :param string: the format string. :param context: a dict with the variables to insert. """ + def lookup_arg(match): x = context[match.group(1) or match.group(2)] if not isinstance(x, string_types): x = type(string)(x) return x + return _format_re.sub(lookup_arg, string) @@ -345,21 +384,26 @@ def secure_filename(filename): """ if isinstance(filename, text_type): from unicodedata import normalize - filename = normalize('NFKD', filename).encode('ascii', 'ignore') + + filename = normalize("NFKD", filename).encode("ascii", "ignore") if not PY2: - filename = filename.decode('ascii') + filename = filename.decode("ascii") for sep in os.path.sep, os.path.altsep: if sep: - filename = filename.replace(sep, ' ') - filename = str(_filename_ascii_strip_re.sub('', '_'.join( - filename.split()))).strip('._') + filename = filename.replace(sep, " ") + filename = str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip( + "._" + ) # on nt a couple of special files are present in each folder. We # have to ensure that the target file is not such a filename. In # this case we prepend an underline - if os.name == 'nt' and filename and \ - filename.split('.')[0].upper() in _windows_device_files: - filename = '_' + filename + if ( + os.name == "nt" + and filename + and filename.split(".")[0].upper() in _windows_device_files + ): + filename = "_" + filename return filename @@ -376,21 +420,26 @@ def escape(s, quote=None): :param quote: ignored. """ if s is None: - return '' - elif hasattr(s, '__html__'): + return "" + elif hasattr(s, "__html__"): return text_type(s.__html__()) elif not isinstance(s, string_types): s = text_type(s) if quote is not None: from warnings import warn + warn( "The 'quote' parameter is no longer used as of version 0.9" " and will be removed in version 1.0.", DeprecationWarning, stacklevel=2, ) - s = s.replace('&', '&').replace('<', '<') \ - .replace('>', '>').replace('"', """) + s = ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + ) return s @@ -400,18 +449,20 @@ def unescape(s): :param s: the string to unescape. """ + def handle_match(m): name = m.group(1) if name in HTMLBuilder._entities: return unichr(HTMLBuilder._entities[name]) try: - if name[:2] in ('#x', '#X'): + if name[:2] in ("#x", "#X"): return unichr(int(name[2:], 16)) - elif name.startswith('#'): + elif name.startswith("#"): return unichr(int(name[1:])) except ValueError: pass - return u'' + return u"" + return _entity_re.sub(handle_match, s) @@ -436,22 +487,26 @@ def redirect(location, code=302, Response=None): unspecified. """ if Response is None: - from werkzeug.wrappers import Response + from .wrappers import Response display_location = escape(location) if isinstance(location, text_type): # Safe conversion is necessary here as we might redirect # to a broken URI scheme (for instance itms-services). - from werkzeug.urls import iri_to_uri + from .urls import iri_to_uri + location = iri_to_uri(location, safe_conversion=True) response = Response( '<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n' - '<title>Redirecting...</title>\n' - '<h1>Redirecting...</h1>\n' - '<p>You should be redirected automatically to target URL: ' - '<a href="%s">%s</a>. If not click the link.' % - (escape(location), display_location), code, mimetype='text/html') - response.headers['Location'] = location + "<title>Redirecting...</title>\n" + "<h1>Redirecting...</h1>\n" + "<p>You should be redirected automatically to target URL: " + '<a href="%s">%s</a>. If not click the link.' + % (escape(location), display_location), + code, + mimetype="text/html", + ) + response.headers["Location"] = location return response @@ -463,10 +518,10 @@ def append_slash_redirect(environ, code=301): the redirect. :param code: the status code for the redirect. """ - new_path = environ['PATH_INFO'].strip('/') + '/' - query_string = environ.get('QUERY_STRING') + new_path = environ["PATH_INFO"].strip("/") + "/" + query_string = environ.get("QUERY_STRING") if query_string: - new_path += '?' + query_string + new_path += "?" + query_string return redirect(new_path, code) @@ -486,17 +541,17 @@ def import_string(import_name, silent=False): # force the import name to automatically convert to strings # __import__ is not able to handle unicode strings in the fromlist # if the module is a package - import_name = str(import_name).replace(':', '.') + import_name = str(import_name).replace(":", ".") try: try: __import__(import_name) except ImportError: - if '.' not in import_name: + if "." not in import_name: raise else: return sys.modules[import_name] - module_name, obj_name = import_name.rsplit('.', 1) + module_name, obj_name = import_name.rsplit(".", 1) module = __import__(module_name, globals(), locals(), [obj_name]) try: return getattr(module, obj_name) @@ -506,9 +561,8 @@ def import_string(import_name, silent=False): except ImportError as e: if not silent: reraise( - ImportStringError, - ImportStringError(import_name, e), - sys.exc_info()[2]) + ImportStringError, ImportStringError(import_name, e), sys.exc_info()[2] + ) def find_modules(import_path, include_packages=False, recursive=False): @@ -527,11 +581,11 @@ def find_modules(import_path, include_packages=False, recursive=False): :return: generator """ module = import_string(import_path) - path = getattr(module, '__path__', None) + path = getattr(module, "__path__", None) if path is None: - raise ValueError('%r is not a package' % import_path) - basename = module.__name__ + '.' - for importer, modname, ispkg in pkgutil.iter_modules(path): + raise ValueError("%r is not a package" % import_path) + basename = module.__name__ + "." + for _importer, modname, ispkg in pkgutil.iter_modules(path): modname = basename + modname if ispkg: if include_packages: @@ -608,24 +662,32 @@ def bind_arguments(func, args, kwargs): :param kwargs: a dict of keyword arguments. :return: a :class:`dict` of bound keyword arguments. """ - args, kwargs, missing, extra, extra_positional, \ - arg_spec, vararg_var, kwarg_var = _parse_signature(func)(args, kwargs) + ( + args, + kwargs, + missing, + extra, + extra_positional, + arg_spec, + vararg_var, + kwarg_var, + ) = _parse_signature(func)(args, kwargs) values = {} - for (name, has_default, default), value in zip(arg_spec, args): + for (name, _has_default, _default), value in zip(arg_spec, args): values[name] = value if vararg_var is not None: values[vararg_var] = tuple(extra_positional) elif extra_positional: - raise TypeError('too many positional arguments') + raise TypeError("too many positional arguments") if kwarg_var is not None: multikw = set(extra) & set([x[0] for x in arg_spec]) if multikw: - raise TypeError('got multiple values for keyword argument ' - + repr(next(iter(multikw)))) + raise TypeError( + "got multiple values for keyword argument " + repr(next(iter(multikw))) + ) values[kwarg_var] = extra elif extra: - raise TypeError('got unexpected keyword argument ' - + repr(next(iter(extra)))) + raise TypeError("got unexpected keyword argument " + repr(next(iter(extra)))) return values @@ -637,15 +699,14 @@ class ArgumentValidationError(ValueError): self.missing = set(missing or ()) self.extra = extra or {} self.extra_positional = extra_positional or [] - ValueError.__init__(self, 'function arguments invalid. (' - '%d missing, %d additional)' % ( - len(self.missing), - len(self.extra) + len(self.extra_positional) - )) + ValueError.__init__( + self, + "function arguments invalid. (%d missing, %d additional)" + % (len(self.missing), len(self.extra) + len(self.extra_positional)), + ) class ImportStringError(ImportError): - """Provides information about a failed :func:`import_string` attempt.""" #: String in dotted notation that failed to be imported. @@ -658,47 +719,51 @@ class ImportStringError(ImportError): self.exception = exception msg = ( - 'import_string() failed for %r. Possible reasons are:\n\n' - '- missing __init__.py in a package;\n' - '- package or module path not included in sys.path;\n' - '- duplicated package or module name taking precedence in ' - 'sys.path;\n' - '- missing module, class, function or variable;\n\n' - 'Debugged import:\n\n%s\n\n' - 'Original exception:\n\n%s: %s') - - name = '' + "import_string() failed for %r. Possible reasons are:\n\n" + "- missing __init__.py in a package;\n" + "- package or module path not included in sys.path;\n" + "- duplicated package or module name taking precedence in " + "sys.path;\n" + "- missing module, class, function or variable;\n\n" + "Debugged import:\n\n%s\n\n" + "Original exception:\n\n%s: %s" + ) + + name = "" tracked = [] - for part in import_name.replace(':', '.').split('.'): - name += (name and '.') + part + for part in import_name.replace(":", ".").split("."): + name += (name and ".") + part imported = import_string(name, silent=True) if imported: - tracked.append((name, getattr(imported, '__file__', None))) + tracked.append((name, getattr(imported, "__file__", None))) else: - track = ['- %r found in %r.' % (n, i) for n, i in tracked] - track.append('- %r not found.' % name) - msg = msg % (import_name, '\n'.join(track), - exception.__class__.__name__, str(exception)) + track = ["- %r found in %r." % (n, i) for n, i in tracked] + track.append("- %r not found." % name) + msg = msg % ( + import_name, + "\n".join(track), + exception.__class__.__name__, + str(exception), + ) break ImportError.__init__(self, msg) def __repr__(self): - return '<%s(%r, %r)>' % (self.__class__.__name__, self.import_name, - self.exception) + return "<%s(%r, %r)>" % ( + self.__class__.__name__, + self.import_name, + self.exception, + ) # DEPRECATED -from werkzeug.datastructures import ( - MultiDict as _MultiDict, - CombinedMultiDict as _CombinedMultiDict, - Headers as _Headers, - EnvironHeaders as _EnvironHeaders, -) -from werkzeug.http import ( - parse_cookie as _parse_cookie, - dump_cookie as _dump_cookie, -) +from .datastructures import CombinedMultiDict as _CombinedMultiDict +from .datastructures import EnvironHeaders as _EnvironHeaders +from .datastructures import Headers as _Headers +from .datastructures import MultiDict as _MultiDict +from .http import dump_cookie as _dump_cookie +from .http import parse_cookie as _parse_cookie class MultiDict(_MultiDict): diff --git a/src/werkzeug/wrappers/accept.py b/src/werkzeug/wrappers/accept.py index 9fe50141..d0620a0a 100644 --- a/src/werkzeug/wrappers/accept.py +++ b/src/werkzeug/wrappers/accept.py @@ -17,15 +17,16 @@ class AcceptMixin(object): """List of mimetypes this client supports as :class:`~werkzeug.datastructures.MIMEAccept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT'), MIMEAccept) + return parse_accept_header(self.environ.get("HTTP_ACCEPT"), MIMEAccept) @cached_property def accept_charsets(self): """List of charsets this client supports as :class:`~werkzeug.datastructures.CharsetAccept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_CHARSET'), - CharsetAccept) + return parse_accept_header( + self.environ.get("HTTP_ACCEPT_CHARSET"), CharsetAccept + ) @cached_property def accept_encodings(self): @@ -33,7 +34,7 @@ class AcceptMixin(object): are compression encodings such as gzip. For charsets have a look at :attr:`accept_charset`. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_ENCODING')) + return parse_accept_header(self.environ.get("HTTP_ACCEPT_ENCODING")) @cached_property def accept_languages(self): @@ -44,5 +45,6 @@ class AcceptMixin(object): In previous versions this was a regular :class:`~werkzeug.datastructures.Accept` object. """ - return parse_accept_header(self.environ.get('HTTP_ACCEPT_LANGUAGE'), - LanguageAccept) + return parse_accept_header( + self.environ.get("HTTP_ACCEPT_LANGUAGE"), LanguageAccept + ) diff --git a/src/werkzeug/wrappers/auth.py b/src/werkzeug/wrappers/auth.py index 2660e9ba..714f7554 100644 --- a/src/werkzeug/wrappers/auth.py +++ b/src/werkzeug/wrappers/auth.py @@ -12,7 +12,7 @@ class AuthorizationMixin(object): @cached_property def authorization(self): """The `Authorization` object in parsed form.""" - header = self.environ.get('HTTP_AUTHORIZATION') + header = self.environ.get("HTTP_AUTHORIZATION") return parse_authorization_header(header) @@ -22,10 +22,12 @@ class WWWAuthenticateMixin(object): @property def www_authenticate(self): """The `WWW-Authenticate` header in a parsed form.""" + def on_update(www_auth): - if not www_auth and 'www-authenticate' in self.headers: - del self.headers['www-authenticate'] + if not www_auth and "www-authenticate" in self.headers: + del self.headers["www-authenticate"] elif www_auth: - self.headers['WWW-Authenticate'] = www_auth.to_header() - header = self.headers.get('www-authenticate') + self.headers["WWW-Authenticate"] = www_auth.to_header() + + header = self.headers.get("www-authenticate") return parse_www_authenticate_header(header, on_update) diff --git a/src/werkzeug/wrappers/base_request.py b/src/werkzeug/wrappers/base_request.py index c106c9c4..f5c40be9 100644 --- a/src/werkzeug/wrappers/base_request.py +++ b/src/werkzeug/wrappers/base_request.py @@ -72,10 +72,10 @@ class BaseRequest(object): """ #: the charset for the request, defaults to utf-8 - charset = 'utf-8' + charset = "utf-8" #: the error handling procedure for errors, defaults to 'replace' - encoding_errors = 'replace' + encoding_errors = "replace" #: the maximum content length. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -150,7 +150,7 @@ class BaseRequest(object): def __init__(self, environ, populate_request=True, shallow=False): self.environ = environ if populate_request and not shallow: - self.environ['werkzeug.request'] = self + self.environ["werkzeug.request"] = self self.shallow = shallow def __repr__(self): @@ -160,14 +160,11 @@ class BaseRequest(object): args = [] try: args.append("'%s'" % to_native(self.url, self.url_charset)) - args.append('[%s]' % self.method) + args.append("[%s]" % self.method) except Exception: - args.append('(invalid WSGI environ)') + args.append("(invalid WSGI environ)") - return '<%s %s>' % ( - self.__class__.__name__, - ' '.join(args) - ) + return "<%s %s>" % (self.__class__.__name__, " ".join(args)) @property def url_charset(self): @@ -197,9 +194,10 @@ class BaseRequest(object): :return: request object """ - from werkzeug.test import EnvironBuilder - charset = kwargs.pop('charset', cls.charset) - kwargs['charset'] = charset + from ..test import EnvironBuilder + + charset = kwargs.pop("charset", cls.charset) + kwargs["charset"] = charset builder = EnvironBuilder(*args, **kwargs) try: return builder.get_request(cls) @@ -228,7 +226,7 @@ class BaseRequest(object): #: the request. The return value is then called with the latest #: two arguments. This makes it possible to use this decorator for #: both methods and standalone WSGI functions. - from werkzeug.exceptions import HTTPException + from ..exceptions import HTTPException def application(*args): request = cls(args[-2]) @@ -241,8 +239,9 @@ class BaseRequest(object): return update_wrapper(application, f) - def _get_file_stream(self, total_content_length, content_type, filename=None, - content_length=None): + def _get_file_stream( + self, total_content_length, content_type, filename=None, content_length=None + ): """Called to get a stream for the file upload. This must provide a file-like class with `read()`, `readline()` @@ -266,7 +265,8 @@ class BaseRequest(object): total_content_length=total_content_length, content_type=content_type, filename=filename, - content_length=content_length) + content_length=content_length, + ) @property def want_form_data_parsed(self): @@ -275,7 +275,7 @@ class BaseRequest(object): .. versionadded:: 0.8 """ - return bool(self.environ.get('CONTENT_TYPE')) + return bool(self.environ.get("CONTENT_TYPE")) def make_form_data_parser(self): """Creates the form data parser. Instantiates the @@ -283,12 +283,14 @@ class BaseRequest(object): .. versionadded:: 0.8 """ - return self.form_data_parser_class(self._get_file_stream, - self.charset, - self.encoding_errors, - self.max_form_memory_size, - self.max_content_length, - self.parameter_storage_class) + return self.form_data_parser_class( + self._get_file_stream, + self.charset, + self.encoding_errors, + self.max_form_memory_size, + self.max_content_length, + self.parameter_storage_class, + ) def _load_form_data(self): """Method used internally to retrieve submitted data. After calling @@ -300,26 +302,30 @@ class BaseRequest(object): .. versionadded:: 0.8 """ # abort early if we have already consumed the stream - if 'form' in self.__dict__: + if "form" in self.__dict__: return _assert_not_shallow(self) if self.want_form_data_parsed: - content_type = self.environ.get('CONTENT_TYPE', '') + content_type = self.environ.get("CONTENT_TYPE", "") content_length = get_content_length(self.environ) mimetype, options = parse_options_header(content_type) parser = self.make_form_data_parser() - data = parser.parse(self._get_stream_for_parsing(), - mimetype, content_length, options) + data = parser.parse( + self._get_stream_for_parsing(), mimetype, content_length, options + ) else: - data = (self.stream, self.parameter_storage_class(), - self.parameter_storage_class()) + data = ( + self.stream, + self.parameter_storage_class(), + self.parameter_storage_class(), + ) # inject the values into the instance dict so that we bypass # our cached_property non-data descriptor. d = self.__dict__ - d['stream'], d['form'], d['files'] = data + d["stream"], d["form"], d["files"] = data def _get_stream_for_parsing(self): """This is the same as accessing :attr:`stream` with the difference @@ -328,7 +334,7 @@ class BaseRequest(object): .. versionadded:: 0.9.3 """ - cached_data = getattr(self, '_cached_data', None) + cached_data = getattr(self, "_cached_data", None) if cached_data is not None: return BytesIO(cached_data) return self.stream @@ -340,8 +346,8 @@ class BaseRequest(object): .. versionadded:: 0.9 """ - files = self.__dict__.get('files') - for key, value in iter_multi_items(files or ()): + files = self.__dict__.get("files") + for _key, value in iter_multi_items(files or ()): value.close() def __enter__(self): @@ -371,12 +377,14 @@ class BaseRequest(object): _assert_not_shallow(self) return get_input_stream(self.environ) - input_stream = environ_property('wsgi.input', """ - The WSGI input stream. + input_stream = environ_property( + "wsgi.input", + """The WSGI input stream. - In general it's a bad idea to use this one because you can easily read past - the boundary. Use the :attr:`stream` instead. - """) + In general it's a bad idea to use this one because you can + easily read past the boundary. Use the :attr:`stream` + instead.""", + ) @cached_property def args(self): @@ -389,9 +397,12 @@ class BaseRequest(object): :attr:`parameter_storage_class` to a different type. This might be necessary if the order of the form data is important. """ - return url_decode(wsgi_get_bytes(self.environ.get('QUERY_STRING', '')), - self.url_charset, errors=self.encoding_errors, - cls=self.parameter_storage_class) + return url_decode( + wsgi_get_bytes(self.environ.get("QUERY_STRING", "")), + self.url_charset, + errors=self.encoding_errors, + cls=self.parameter_storage_class, + ) @cached_property def data(self): @@ -401,7 +412,7 @@ class BaseRequest(object): """ if self.disable_data_descriptor: - raise AttributeError('data descriptor is disabled') + raise AttributeError("data descriptor is disabled") # XXX: this should eventually be deprecated. # We trigger form data parsing first which means that the descriptor @@ -436,7 +447,7 @@ class BaseRequest(object): .. versionadded:: 0.9 """ - rv = getattr(self, '_cached_data', None) + rv = getattr(self, "_cached_data", None) if rv is None: if parse_form_data: self._load_form_data() @@ -504,9 +515,12 @@ class BaseRequest(object): def cookies(self): """A :class:`dict` with the contents of all cookies transmitted with the request.""" - return parse_cookie(self.environ, self.charset, - self.encoding_errors, - cls=self.dict_storage_class) + return parse_cookie( + self.environ, + self.charset, + self.encoding_errors, + cls=self.dict_storage_class, + ) @cached_property def headers(self): @@ -521,37 +535,39 @@ class BaseRequest(object): info in the WSGI environment but will always include a leading slash, even if the URL root is accessed. """ - raw_path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '', - self.charset, self.encoding_errors) - return '/' + raw_path.lstrip('/') + raw_path = wsgi_decoding_dance( + self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors + ) + return "/" + raw_path.lstrip("/") @cached_property def full_path(self): """Requested path as unicode, including the query string.""" - return self.path + u'?' + to_unicode(self.query_string, self.url_charset) + return self.path + u"?" + to_unicode(self.query_string, self.url_charset) @cached_property def script_root(self): """The root path of the script without the trailing slash.""" - raw_path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '', - self.charset, self.encoding_errors) - return raw_path.rstrip('/') + raw_path = wsgi_decoding_dance( + self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors + ) + return raw_path.rstrip("/") @cached_property def url(self): """The reconstructed current URL as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, - trusted_hosts=self.trusted_hosts) + return get_current_url(self.environ, trusted_hosts=self.trusted_hosts) @cached_property def base_url(self): """Like :attr:`url` but without the querystring See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, strip_querystring=True, - trusted_hosts=self.trusted_hosts) + return get_current_url( + self.environ, strip_querystring=True, trusted_hosts=self.trusted_hosts + ) @cached_property def url_root(self): @@ -559,16 +575,16 @@ class BaseRequest(object): root as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, True, - trusted_hosts=self.trusted_hosts) + return get_current_url(self.environ, True, trusted_hosts=self.trusted_hosts) @cached_property def host_url(self): """Just the host with scheme as IRI. See also: :attr:`trusted_hosts`. """ - return get_current_url(self.environ, host_only=True, - trusted_hosts=self.trusted_hosts) + return get_current_url( + self.environ, host_only=True, trusted_hosts=self.trusted_hosts + ) @cached_property def host(self): @@ -578,39 +594,51 @@ class BaseRequest(object): return get_host(self.environ, trusted_hosts=self.trusted_hosts) query_string = environ_property( - 'QUERY_STRING', '', read_only=True, - load_func=wsgi_get_bytes, doc='The URL parameters as raw bytestring.') + "QUERY_STRING", + "", + read_only=True, + load_func=wsgi_get_bytes, + doc="The URL parameters as raw bytestring.", + ) method = environ_property( - 'REQUEST_METHOD', 'GET', read_only=True, + "REQUEST_METHOD", + "GET", + read_only=True, load_func=lambda x: x.upper(), - doc="The request method. (For example ``'GET'`` or ``'POST'``).") + doc="The request method. (For example ``'GET'`` or ``'POST'``).", + ) @cached_property def access_route(self): """If a forwarded header exists this is a list of all ip addresses from the client ip to the last proxy server. """ - if 'HTTP_X_FORWARDED_FOR' in self.environ: - addr = self.environ['HTTP_X_FORWARDED_FOR'].split(',') + if "HTTP_X_FORWARDED_FOR" in self.environ: + addr = self.environ["HTTP_X_FORWARDED_FOR"].split(",") return self.list_storage_class([x.strip() for x in addr]) - elif 'REMOTE_ADDR' in self.environ: - return self.list_storage_class([self.environ['REMOTE_ADDR']]) + elif "REMOTE_ADDR" in self.environ: + return self.list_storage_class([self.environ["REMOTE_ADDR"]]) return self.list_storage_class() @property def remote_addr(self): """The remote address of the client.""" - return self.environ.get('REMOTE_ADDR') - - remote_user = environ_property('REMOTE_USER', doc=''' - If the server supports user authentication, and the script is - protected, this attribute contains the username the user has - authenticated as.''') - - scheme = environ_property('wsgi.url_scheme', doc=''' + return self.environ.get("REMOTE_ADDR") + + remote_user = environ_property( + "REMOTE_USER", + doc="""If the server supports user authentication, and the + script is protected, this attribute contains the username the + user has authenticated as.""", + ) + + scheme = environ_property( + "wsgi.url_scheme", + doc=""" URL scheme (http or https). - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) @property def is_xhr(self): @@ -630,28 +658,36 @@ class BaseRequest(object): " is not standard and is unreliable. You may be able to use" " 'accept_mimetypes' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) - return self.environ.get( - 'HTTP_X_REQUESTED_WITH', '' - ).lower() == 'xmlhttprequest' - - is_secure = property(lambda self: self.environ['wsgi.url_scheme'] == 'https', - doc='`True` if the request is secure.') - is_multithread = environ_property('wsgi.multithread', doc=''' - boolean that is `True` if the application is served by - a multithreaded WSGI server.''') - is_multiprocess = environ_property('wsgi.multiprocess', doc=''' - boolean that is `True` if the application is served by - a WSGI server that spawns multiple processes.''') - is_run_once = environ_property('wsgi.run_once', doc=''' - boolean that is `True` if the application will be executed only - once in a process lifetime. This is the case for CGI for example, - but it's not guaranteed that the execution only happens one time.''') + return self.environ.get("HTTP_X_REQUESTED_WITH", "").lower() == "xmlhttprequest" + + is_secure = property( + lambda self: self.environ["wsgi.url_scheme"] == "https", + doc="`True` if the request is secure.", + ) + is_multithread = environ_property( + "wsgi.multithread", + doc="""boolean that is `True` if the application is served by a + multithreaded WSGI server.""", + ) + is_multiprocess = environ_property( + "wsgi.multiprocess", + doc="""boolean that is `True` if the application is served by a + WSGI server that spawns multiple processes.""", + ) + is_run_once = environ_property( + "wsgi.run_once", + doc="""boolean that is `True` if the application will be + executed only once in a process lifetime. This is the case for + CGI for example, but it's not guaranteed that the execution only + happens one time.""", + ) def _assert_not_shallow(request): if request.shallow: - raise RuntimeError('A shallow request tried to consume ' - 'form data. If you really want to do ' - 'that, set `shallow` to False.') + raise RuntimeError( + "A shallow request tried to consume form data. If you really" + " want to do that, set `shallow` to False." + ) diff --git a/src/werkzeug/wrappers/base_response.py b/src/werkzeug/wrappers/base_response.py index d7a8763b..d944a7d2 100644 --- a/src/werkzeug/wrappers/base_response.py +++ b/src/werkzeug/wrappers/base_response.py @@ -22,6 +22,7 @@ def _run_wsgi_app(*args): """ global _run_wsgi_app from ..test import run_wsgi_app as _run_wsgi_app + return _run_wsgi_app(*args) @@ -36,7 +37,7 @@ def _warn_if_string(iterable): " client one character at a time. This is almost never" " intended behavior, use 'response.data' to assign strings" " to the response object.", - stacklevel=2 + stacklevel=2, ) @@ -129,13 +130,13 @@ class BaseResponse(object): """ #: the charset of the response. - charset = 'utf-8' + charset = "utf-8" #: the default status if none is provided. default_status = 200 #: the default mimetype if none is provided. - default_mimetype = 'text/plain' + default_mimetype = "text/plain" #: if set to `False` accessing properties on the response object will #: not try to consume the response iterator and convert it into a list. @@ -169,8 +170,15 @@ class BaseResponse(object): #: .. _`cookie`: http://browsercookielimits.squawky.net/ max_cookie_size = 4093 - def __init__(self, response=None, status=None, headers=None, - mimetype=None, content_type=None, direct_passthrough=False): + def __init__( + self, + response=None, + status=None, + headers=None, + mimetype=None, + content_type=None, + direct_passthrough=False, + ): if isinstance(headers, Headers): self.headers = headers elif not headers: @@ -179,13 +187,13 @@ class BaseResponse(object): self.headers = Headers(headers) if content_type is None: - if mimetype is None and 'content-type' not in self.headers: + if mimetype is None and "content-type" not in self.headers: mimetype = self.default_mimetype if mimetype is not None: mimetype = get_content_type(mimetype, self.charset) content_type = mimetype if content_type is not None: - self.headers['Content-Type'] = content_type + self.headers["Content-Type"] = content_type if status is None: status = self.default_status if isinstance(status, integer_types): @@ -218,14 +226,10 @@ class BaseResponse(object): def __repr__(self): if self.is_sequence: - body_info = '%d bytes' % sum(map(len, self.iter_encoded())) + body_info = "%d bytes" % sum(map(len, self.iter_encoded())) else: - body_info = 'streamed' if self.is_streamed else 'likely-streamed' - return '<%s %s [%s]>' % ( - self.__class__.__name__, - body_info, - self.status - ) + body_info = "streamed" if self.is_streamed else "likely-streamed" + return "<%s %s [%s]>" % (self.__class__.__name__, body_info, self.status) @classmethod def force_type(cls, response, environ=None): @@ -258,8 +262,10 @@ class BaseResponse(object): """ if not isinstance(response, BaseResponse): if environ is None: - raise TypeError('cannot convert WSGI application into ' - 'response objects without an environ') + raise TypeError( + "cannot convert WSGI application into response" + " objects without an environ" + ) response = BaseResponse(*_run_wsgi_app(response, environ)) response.__class__ = cls return response @@ -286,11 +292,13 @@ class BaseResponse(object): def _set_status_code(self, code): self._status_code = code try: - self._status = '%d %s' % (code, HTTP_STATUS_CODES[code].upper()) + self._status = "%d %s" % (code, HTTP_STATUS_CODES[code].upper()) except KeyError: - self._status = '%d UNKNOWN' % code - status_code = property(_get_status_code, _set_status_code, - doc='The HTTP Status code as number') + self._status = "%d UNKNOWN" % code + + status_code = property( + _get_status_code, _set_status_code, doc="The HTTP Status code as number" + ) del _get_status_code, _set_status_code def _get_status(self): @@ -300,16 +308,17 @@ class BaseResponse(object): try: self._status = to_native(value) except AttributeError: - raise TypeError('Invalid status argument') + raise TypeError("Invalid status argument") try: self._status_code = int(self._status.split(None, 1)[0]) except ValueError: self._status_code = 0 - self._status = '0 %s' % self._status + self._status = "0 %s" % self._status except IndexError: - raise ValueError('Empty status argument') - status = property(_get_status, _set_status, doc='The HTTP Status code') + raise ValueError("Empty status argument") + + status = property(_get_status, _set_status, doc="The HTTP Status code") del _get_status, _set_status def get_data(self, as_text=False): @@ -326,7 +335,7 @@ class BaseResponse(object): .. versionadded:: 0.9 """ self._ensure_sequence() - rv = b''.join(self.iter_encoded()) + rv = b"".join(self.iter_encoded()) if as_text: rv = rv.decode(self.charset) return rv @@ -346,12 +355,12 @@ class BaseResponse(object): value = bytes(value) self.response = [value] if self.automatically_set_content_length: - self.headers['Content-Length'] = str(len(value)) + self.headers["Content-Length"] = str(len(value)) data = property( get_data, set_data, - doc='A descriptor that calls :meth:`get_data` and :meth:`set_data`.' + doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.", ) def calculate_content_length(self): @@ -375,14 +384,16 @@ class BaseResponse(object): self.response = list(self.response) return if self.direct_passthrough: - raise RuntimeError('Attempted implicit sequence conversion ' - 'but the response object is in direct ' - 'passthrough mode.') + raise RuntimeError( + "Attempted implicit sequence conversion but the" + " response object is in direct passthrough mode." + ) if not self.implicit_sequence_conversion: - raise RuntimeError('The response object required the iterable ' - 'to be a sequence, but the implicit ' - 'conversion was disabled. Call ' - 'make_sequence() yourself.') + raise RuntimeError( + "The response object required the iterable to be a" + " sequence, but the implicit conversion was disabled." + " Call make_sequence() yourself." + ) self.make_sequence() def make_sequence(self): @@ -397,7 +408,7 @@ class BaseResponse(object): # if we consume an iterable we have to ensure that the close # method of the iterable is called if available when we tear # down the response - close = getattr(self.response, 'close', None) + close = getattr(self.response, "close", None) self.response = list(self.iter_encoded()) if close is not None: self.call_on_close(close) @@ -415,9 +426,18 @@ class BaseResponse(object): # value from get_app_iter or iter_encoded. return _iter_encoded(self.response, self.charset) - def set_cookie(self, key, value='', max_age=None, expires=None, - path='/', domain=None, secure=False, httponly=False, - samesite=None): + def set_cookie( + self, + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=False, + httponly=False, + samesite=None, + ): """Sets a cookie. The parameters are the same as in the cookie `Morsel` object in the Python standard library but it accepts unicode data, too. @@ -445,21 +465,24 @@ class BaseResponse(object): be attached to requests if those requests are "same-site". """ - self.headers.add('Set-Cookie', dump_cookie( - key, - value=value, - max_age=max_age, - expires=expires, - path=path, - domain=domain, - secure=secure, - httponly=httponly, - charset=self.charset, - max_size=self.max_cookie_size, - samesite=samesite - )) - - def delete_cookie(self, key, path='/', domain=None): + self.headers.add( + "Set-Cookie", + dump_cookie( + key, + value=value, + max_age=max_age, + expires=expires, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + charset=self.charset, + max_size=self.max_cookie_size, + samesite=samesite, + ), + ) + + def delete_cookie(self, key, path="/", domain=None): """Delete a cookie. Fails silently if key doesn't exist. :param key: the key (name) of the cookie to be deleted. @@ -503,7 +526,7 @@ class BaseResponse(object): .. versionadded:: 0.9 Can now be used in a with statement. """ - if hasattr(self.response, 'close'): + if hasattr(self.response, "close"): self.response.close() for func in self._on_close: func() @@ -525,7 +548,7 @@ class BaseResponse(object): # we explicitly set the length to a list of the *encoded* response # iterator. Even if the implicit sequence conversion is disabled. self.response = list(self.iter_encoded()) - self.headers['Content-Length'] = str(sum(map(len, self.response))) + self.headers["Content-Length"] = str(sum(map(len, self.response))) def get_wsgi_headers(self, environ): """This is automatically called right before the response is started @@ -562,11 +585,11 @@ class BaseResponse(object): # speedup. for key, value in headers: ikey = key.lower() - if ikey == u'location': + if ikey == u"location": location = value - elif ikey == u'content-location': + elif ikey == u"content-location": content_location = value - elif ikey == u'content-length': + elif ikey == u"content-length": content_length = value # make sure the location header is an absolute URL @@ -583,17 +606,17 @@ class BaseResponse(object): current_url = iri_to_uri(current_url) location = url_join(current_url, location) if location != old_location: - headers['Location'] = location + headers["Location"] = location # make sure the content location is a URL - if content_location is not None and \ - isinstance(content_location, text_type): - headers['Content-Location'] = iri_to_uri(content_location) + if content_location is not None and isinstance(content_location, text_type): + headers["Content-Location"] = iri_to_uri(content_location) if 100 <= status < 200 or status == 204: - # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a Content-Length header field - # in any response with a status code of 1xx (Informational) or 204 (No Content)." - headers.remove('Content-Length') + # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a + # Content-Length header field in any response with a status + # code of 1xx (Informational) or 204 (No Content)." + headers.remove("Content-Length") elif status == 304: remove_entity_headers(headers) @@ -602,19 +625,21 @@ class BaseResponse(object): # flattening the iterator or encoding of unicode strings in # the response. We however should not do that if we have a 304 # response. - if self.automatically_set_content_length and \ - self.is_sequence and content_length is None and \ - status not in (204, 304) and \ - not (100 <= status < 200): + if ( + self.automatically_set_content_length + and self.is_sequence + and content_length is None + and status not in (204, 304) + and not (100 <= status < 200) + ): try: - content_length = sum(len(to_bytes(x, 'ascii')) - for x in self.response) + content_length = sum(len(to_bytes(x, "ascii")) for x in self.response) except UnicodeError: # aha, something non-bytestringy in there, too bad, we # can't safely figure out the length of the response. pass else: - headers['Content-Length'] = str(content_length) + headers["Content-Length"] = str(content_length) return headers @@ -633,8 +658,11 @@ class BaseResponse(object): :return: a response iterable. """ status = self.status_code - if environ['REQUEST_METHOD'] == 'HEAD' or \ - 100 <= status < 200 or status in (204, 304): + if ( + environ["REQUEST_METHOD"] == "HEAD" + or 100 <= status < 200 + or status in (204, 304) + ): iterable = () elif self.direct_passthrough: if __debug__: diff --git a/src/werkzeug/wrappers/common_descriptors.py b/src/werkzeug/wrappers/common_descriptors.py index 3ad0474a..e4107ee0 100644 --- a/src/werkzeug/wrappers/common_descriptors.py +++ b/src/werkzeug/wrappers/common_descriptors.py @@ -26,11 +26,13 @@ class CommonRequestDescriptorsMixin(object): .. versionadded:: 0.5 """ - content_type = environ_property('CONTENT_TYPE', doc=''' - The Content-Type entity-header field indicates the media type of - the entity-body sent to the recipient or, in the case of the HEAD - method, the media type that would have been sent had the request - been a GET.''') + content_type = environ_property( + "CONTENT_TYPE", + doc="""The Content-Type entity-header field indicates the media + type of the entity-body sent to the recipient or, in the case of + the HEAD method, the media type that would have been sent had + the request been a GET.""", + ) @cached_property def content_length(self): @@ -41,40 +43,58 @@ class CommonRequestDescriptorsMixin(object): """ return get_content_length(self.environ) - content_encoding = environ_property('HTTP_CONTENT_ENCODING', doc=''' - The Content-Encoding entity-header field is used as a modifier to the - media-type. When present, its value indicates what additional content - codings have been applied to the entity-body, and thus what decoding - mechanisms must be applied in order to obtain the media-type - referenced by the Content-Type header field. - - .. versionadded:: 0.9''') - content_md5 = environ_property('HTTP_CONTENT_MD5', doc=''' - The Content-MD5 entity-header field, as defined in RFC 1864, is an - MD5 digest of the entity-body for the purpose of providing an - end-to-end message integrity check (MIC) of the entity-body. (Note: - a MIC is good for detecting accidental modification of the - entity-body in transit, but is not proof against malicious attacks.) - - .. versionadded:: 0.9''') - referrer = environ_property('HTTP_REFERER', doc=''' - The Referer[sic] request-header field allows the client to specify, - for the server's benefit, the address (URI) of the resource from which - the Request-URI was obtained (the "referrer", although the header - field is misspelled).''') - date = environ_property('HTTP_DATE', None, parse_date, doc=''' - The Date general-header field represents the date and time at which - the message was originated, having the same semantics as orig-date - in RFC 822.''') - max_forwards = environ_property('HTTP_MAX_FORWARDS', None, int, doc=''' - The Max-Forwards request-header field provides a mechanism with the - TRACE and OPTIONS methods to limit the number of proxies or gateways - that can forward the request to the next inbound server.''') + content_encoding = environ_property( + "HTTP_CONTENT_ENCODING", + doc="""The Content-Encoding entity-header field is used as a + modifier to the media-type. When present, its value indicates + what additional content codings have been applied to the + entity-body, and thus what decoding mechanisms must be applied + in order to obtain the media-type referenced by the Content-Type + header field. + + .. versionadded:: 0.9""", + ) + content_md5 = environ_property( + "HTTP_CONTENT_MD5", + doc="""The Content-MD5 entity-header field, as defined in + RFC 1864, is an MD5 digest of the entity-body for the purpose of + providing an end-to-end message integrity check (MIC) of the + entity-body. (Note: a MIC is good for detecting accidental + modification of the entity-body in transit, but is not proof + against malicious attacks.) + + .. versionadded:: 0.9""", + ) + referrer = environ_property( + "HTTP_REFERER", + doc="""The Referer[sic] request-header field allows the client + to specify, for the server's benefit, the address (URI) of the + resource from which the Request-URI was obtained (the + "referrer", although the header field is misspelled).""", + ) + date = environ_property( + "HTTP_DATE", + None, + parse_date, + doc="""The Date general-header field represents the date and + time at which the message was originated, having the same + semantics as orig-date in RFC 822.""", + ) + max_forwards = environ_property( + "HTTP_MAX_FORWARDS", + None, + int, + doc="""The Max-Forwards request-header field provides a + mechanism with the TRACE and OPTIONS methods to limit the number + of proxies or gateways that can forward the request to the next + inbound server.""", + ) def _parse_content_type(self): - if not hasattr(self, '_parsed_content_type'): - self._parsed_content_type = \ - parse_options_header(self.environ.get('CONTENT_TYPE', '')) + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = parse_options_header( + self.environ.get("CONTENT_TYPE", "") + ) @property def mimetype(self): @@ -103,7 +123,7 @@ class CommonRequestDescriptorsMixin(object): optional behavior from the viewpoint of the protocol; however, some systems MAY require that behavior be consistent with the directives. """ - return parse_set_header(self.environ.get('HTTP_PRAGMA', '')) + return parse_set_header(self.environ.get("HTTP_PRAGMA", "")) class CommonResponseDescriptorsMixin(object): @@ -112,115 +132,157 @@ class CommonResponseDescriptorsMixin(object): HTTP headers with automatic type conversion. """ - def _get_mimetype(self): - ct = self.headers.get('content-type') + @property + def mimetype(self): + """The mimetype (content type without charset etc.)""" + ct = self.headers.get("content-type") if ct: - return ct.split(';')[0].strip() + return ct.split(";")[0].strip() + + @mimetype.setter + def mimetype(self, value): + self.headers["Content-Type"] = get_content_type(value, self.charset) + + @property + def mimetype_params(self): + """The mimetype parameters as dict. For example if the + content type is ``text/html; charset=utf-8`` the params would be + ``{'charset': 'utf-8'}``. - def _set_mimetype(self, value): - self.headers['Content-Type'] = get_content_type(value, self.charset) + .. versionadded:: 0.5 + """ - def _get_mimetype_params(self): def on_update(d): - self.headers['Content-Type'] = \ - dump_options_header(self.mimetype, d) - d = parse_options_header(self.headers.get('content-type', ''))[1] + self.headers["Content-Type"] = dump_options_header(self.mimetype, d) + + d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) - mimetype = property(_get_mimetype, _set_mimetype, doc=''' - The mimetype (content type without charset etc.)''') - mimetype_params = property(_get_mimetype_params, doc=''' - The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be - ``{'charset': 'utf-8'}``. + location = header_property( + "Location", + doc="""The Location response-header field is used to redirect + the recipient to a location other than the Request-URI for + completion of the request or identification of a new + resource.""", + ) + age = header_property( + "Age", + None, + parse_age, + dump_age, + doc="""The Age response-header field conveys the sender's + estimate of the amount of time since the response (or its + revalidation) was generated at the origin server. - .. versionadded:: 0.5 - ''') - location = header_property('Location', doc=''' - The Location response-header field is used to redirect the recipient - to a location other than the Request-URI for completion of the request - or identification of a new resource.''') - age = header_property('Age', None, parse_age, dump_age, doc=''' - The Age response-header field conveys the sender's estimate of the - amount of time since the response (or its revalidation) was - generated at the origin server. - - Age values are non-negative decimal integers, representing time in - seconds.''') - content_type = header_property('Content-Type', doc=''' - The Content-Type entity-header field indicates the media type of the - entity-body sent to the recipient or, in the case of the HEAD method, - the media type that would have been sent had the request been a GET. - ''') - content_length = header_property('Content-Length', None, int, str, doc=''' - The Content-Length entity-header field indicates the size of the - entity-body, in decimal number of OCTETs, sent to the recipient or, - in the case of the HEAD method, the size of the entity-body that would - have been sent had the request been a GET.''') - content_location = header_property('Content-Location', doc=''' - The Content-Location entity-header field MAY be used to supply the - resource location for the entity enclosed in the message when that - entity is accessible from a location separate from the requested - resource's URI.''') - content_encoding = header_property('Content-Encoding', doc=''' - The Content-Encoding entity-header field is used as a modifier to the - media-type. When present, its value indicates what additional content - codings have been applied to the entity-body, and thus what decoding - mechanisms must be applied in order to obtain the media-type - referenced by the Content-Type header field.''') - content_md5 = header_property('Content-MD5', doc=''' - The Content-MD5 entity-header field, as defined in RFC 1864, is an - MD5 digest of the entity-body for the purpose of providing an - end-to-end message integrity check (MIC) of the entity-body. (Note: - a MIC is good for detecting accidental modification of the - entity-body in transit, but is not proof against malicious attacks.) - ''') - date = header_property('Date', None, parse_date, http_date, doc=''' - The Date general-header field represents the date and time at which - the message was originated, having the same semantics as orig-date - in RFC 822.''') - expires = header_property('Expires', None, parse_date, http_date, doc=''' - The Expires entity-header field gives the date/time after which the - response is considered stale. A stale cache entry may not normally be - returned by a cache.''') - last_modified = header_property('Last-Modified', None, parse_date, - http_date, doc=''' - The Last-Modified entity-header field indicates the date and time at - which the origin server believes the variant was last modified.''') - - def _get_retry_after(self): - value = self.headers.get('retry-after') + Age values are non-negative decimal integers, representing time + in seconds.""", + ) + content_type = header_property( + "Content-Type", + doc="""The Content-Type entity-header field indicates the media + type of the entity-body sent to the recipient or, in the case of + the HEAD method, the media type that would have been sent had + the request been a GET.""", + ) + content_length = header_property( + "Content-Length", + None, + int, + str, + doc="""The Content-Length entity-header field indicates the size + of the entity-body, in decimal number of OCTETs, sent to the + recipient or, in the case of the HEAD method, the size of the + entity-body that would have been sent had the request been a + GET.""", + ) + content_location = header_property( + "Content-Location", + doc="""The Content-Location entity-header field MAY be used to + supply the resource location for the entity enclosed in the + message when that entity is accessible from a location separate + from the requested resource's URI.""", + ) + content_encoding = header_property( + "Content-Encoding", + doc="""The Content-Encoding entity-header field is used as a + modifier to the media-type. When present, its value indicates + what additional content codings have been applied to the + entity-body, and thus what decoding mechanisms must be applied + in order to obtain the media-type referenced by the Content-Type + header field.""", + ) + content_md5 = header_property( + "Content-MD5", + doc="""The Content-MD5 entity-header field, as defined in + RFC 1864, is an MD5 digest of the entity-body for the purpose of + providing an end-to-end message integrity check (MIC) of the + entity-body. (Note: a MIC is good for detecting accidental + modification of the entity-body in transit, but is not proof + against malicious attacks.)""", + ) + date = header_property( + "Date", + None, + parse_date, + http_date, + doc="""The Date general-header field represents the date and + time at which the message was originated, having the same + semantics as orig-date in RFC 822.""", + ) + expires = header_property( + "Expires", + None, + parse_date, + http_date, + doc="""The Expires entity-header field gives the date/time after + which the response is considered stale. A stale cache entry may + not normally be returned by a cache.""", + ) + last_modified = header_property( + "Last-Modified", + None, + parse_date, + http_date, + doc="""The Last-Modified entity-header field indicates the date + and time at which the origin server believes the variant was + last modified.""", + ) + + @property + def retry_after(self): + """The Retry-After response-header field can be used with a + 503 (Service Unavailable) response to indicate how long the + service is expected to be unavailable to the requesting client. + + Time in seconds until expiration or date. + """ + value = self.headers.get("retry-after") if value is None: return elif value.isdigit(): return datetime.utcnow() + timedelta(seconds=int(value)) return parse_date(value) - def _set_retry_after(self, value): + @retry_after.setter + def retry_after(self, value): if value is None: - if 'retry-after' in self.headers: - del self.headers['retry-after'] + if "retry-after" in self.headers: + del self.headers["retry-after"] return elif isinstance(value, datetime): value = http_date(value) else: value = str(value) - self.headers['Retry-After'] = value - - retry_after = property(_get_retry_after, _set_retry_after, doc=''' - The Retry-After response-header field can be used with a 503 (Service - Unavailable) response to indicate how long the service is expected - to be unavailable to the requesting client. - - Time in seconds until expiration or date.''') + self.headers["Retry-After"] = value - def _set_property(name, doc=None): + def _set_property(name, doc=None): # noqa: B902 def fget(self): def on_update(header_set): if not header_set and name in self.headers: del self.headers[name] elif header_set: self.headers[name] = header_set.to_header() + return parse_set_header(self.headers.get(name), on_update) def fset(self, value): @@ -230,24 +292,31 @@ class CommonResponseDescriptorsMixin(object): self.headers[name] = value else: self.headers[name] = dump_header(value) + return property(fget, fset, doc=doc) - vary = _set_property('Vary', doc=''' - The Vary field value indicates the set of request-header fields that - fully determines, while the response is fresh, whether a cache is - permitted to use the response to reply to a subsequent request - without revalidation.''') - content_language = _set_property('Content-Language', doc=''' - The Content-Language entity-header field describes the natural - language(s) of the intended audience for the enclosed entity. Note - that this might not be equivalent to all the languages used within - the entity-body.''') - allow = _set_property('Allow', doc=''' - The Allow entity-header field lists the set of methods supported - by the resource identified by the Request-URI. The purpose of this - field is strictly to inform the recipient of valid methods - associated with the resource. An Allow header field MUST be - present in a 405 (Method Not Allowed) response.''') - - del _set_property, _get_mimetype, _set_mimetype, _get_retry_after, \ - _set_retry_after + vary = _set_property( + "Vary", + doc="""The Vary field value indicates the set of request-header + fields that fully determines, while the response is fresh, + whether a cache is permitted to use the response to reply to a + subsequent request without revalidation.""", + ) + content_language = _set_property( + "Content-Language", + doc="""The Content-Language entity-header field describes the + natural language(s) of the intended audience for the enclosed + entity. Note that this might not be equivalent to all the + languages used within the entity-body.""", + ) + allow = _set_property( + "Allow", + doc="""The Allow entity-header field lists the set of methods + supported by the resource identified by the Request-URI. The + purpose of this field is strictly to inform the recipient of + valid methods associated with the resource. An Allow header + field MUST be present in a 405 (Method Not Allowed) + response.""", + ) + + del _set_property diff --git a/src/werkzeug/wrappers/etag.py b/src/werkzeug/wrappers/etag.py index 7ad12f61..0733506f 100644 --- a/src/werkzeug/wrappers/etag.py +++ b/src/werkzeug/wrappers/etag.py @@ -31,9 +31,8 @@ class ETagRequestMixin(object): """A :class:`~werkzeug.datastructures.RequestCacheControl` object for the incoming cache control headers. """ - cache_control = self.environ.get('HTTP_CACHE_CONTROL') - return parse_cache_control_header(cache_control, None, - RequestCacheControl) + cache_control = self.environ.get("HTTP_CACHE_CONTROL") + return parse_cache_control_header(cache_control, None, RequestCacheControl) @cached_property def if_match(self): @@ -41,7 +40,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.ETags` """ - return parse_etags(self.environ.get('HTTP_IF_MATCH')) + return parse_etags(self.environ.get("HTTP_IF_MATCH")) @cached_property def if_none_match(self): @@ -49,17 +48,17 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.ETags` """ - return parse_etags(self.environ.get('HTTP_IF_NONE_MATCH')) + return parse_etags(self.environ.get("HTTP_IF_NONE_MATCH")) @cached_property def if_modified_since(self): """The parsed `If-Modified-Since` header as datetime object.""" - return parse_date(self.environ.get('HTTP_IF_MODIFIED_SINCE')) + return parse_date(self.environ.get("HTTP_IF_MODIFIED_SINCE")) @cached_property def if_unmodified_since(self): """The parsed `If-Unmodified-Since` header as datetime object.""" - return parse_date(self.environ.get('HTTP_IF_UNMODIFIED_SINCE')) + return parse_date(self.environ.get("HTTP_IF_UNMODIFIED_SINCE")) @cached_property def if_range(self): @@ -69,7 +68,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.IfRange` """ - return parse_if_range_header(self.environ.get('HTTP_IF_RANGE')) + return parse_if_range_header(self.environ.get("HTTP_IF_RANGE")) @cached_property def range(self): @@ -79,7 +78,7 @@ class ETagRequestMixin(object): :rtype: :class:`~werkzeug.datastructures.Range` """ - return parse_range_header(self.environ.get('HTTP_RANGE')) + return parse_range_header(self.environ.get("HTTP_RANGE")) class ETagResponseMixin(object): @@ -99,14 +98,16 @@ class ETagResponseMixin(object): directives that MUST be obeyed by all caching mechanisms along the request/response chain. """ + def on_update(cache_control): - if not cache_control and 'cache-control' in self.headers: - del self.headers['cache-control'] + if not cache_control and "cache-control" in self.headers: + del self.headers["cache-control"] elif cache_control: - self.headers['Cache-Control'] = cache_control.to_header() - return parse_cache_control_header(self.headers.get('cache-control'), - on_update, - ResponseCacheControl) + self.headers["Cache-Control"] = cache_control.to_header() + + return parse_cache_control_header( + self.headers.get("cache-control"), on_update, ResponseCacheControl + ) def _wrap_response(self, start, length): """Wrap existing Response in case of Range Request context.""" @@ -118,12 +119,15 @@ class ETagResponseMixin(object): resource is considered unchanged when compared with `If-Range` header. """ return ( - 'HTTP_IF_RANGE' not in environ + "HTTP_IF_RANGE" not in environ or not is_resource_modified( - environ, self.headers.get('etag'), None, - self.headers.get('last-modified'), ignore_if_range=False + environ, + self.headers.get("etag"), + None, + self.headers.get("last-modified"), + ignore_if_range=False, ) - ) and 'HTTP_RANGE' in environ + ) and "HTTP_RANGE" in environ def _process_range_request(self, environ, complete_length=None, accept_ranges=None): """Handle Range Request related headers (RFC7233). If `Accept-Ranges` @@ -136,13 +140,14 @@ class ETagResponseMixin(object): :raises: :class:`~werkzeug.exceptions.RequestedRangeNotSatisfiable` if `Range` header could not be parsed or satisfied. """ - from werkzeug.exceptions import RequestedRangeNotSatisfiable + from ..exceptions import RequestedRangeNotSatisfiable + if accept_ranges is None: return False - self.headers['Accept-Ranges'] = accept_ranges + self.headers["Accept-Ranges"] = accept_ranges if not self._is_range_request_processable(environ) or complete_length is None: return False - parsed_range = parse_range_header(environ.get('HTTP_RANGE')) + parsed_range = parse_range_header(environ.get("HTTP_RANGE")) if parsed_range is None: raise RequestedRangeNotSatisfiable(complete_length) range_tuple = parsed_range.range_for_length(complete_length) @@ -153,15 +158,16 @@ class ETagResponseMixin(object): # Be sure not to send 206 response # if requested range is the full content. if content_length != complete_length: - self.headers['Content-Length'] = content_length + self.headers["Content-Length"] = content_length self.content_range = content_range_header self.status_code = 206 self._wrap_response(range_tuple[0], content_length) return True return False - def make_conditional(self, request_or_environ, accept_ranges=False, - complete_length=None): + def make_conditional( + self, request_or_environ, accept_ranges=False, complete_length=None + ): """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -199,43 +205,48 @@ class ETagResponseMixin(object): if `Range` header could not be parsed or satisfied. """ environ = _get_environ(request_or_environ) - if environ['REQUEST_METHOD'] in ('GET', 'HEAD'): + if environ["REQUEST_METHOD"] in ("GET", "HEAD"): # if the date is not in the headers, add it now. We however # will not override an already existing header. Unfortunately # this header will be overriden by many WSGI servers including # wsgiref. - if 'date' not in self.headers: - self.headers['Date'] = http_date() + if "date" not in self.headers: + self.headers["Date"] = http_date() accept_ranges = _clean_accept_ranges(accept_ranges) is206 = self._process_range_request(environ, complete_length, accept_ranges) if not is206 and not is_resource_modified( - environ, self.headers.get('etag'), None, - self.headers.get('last-modified') + environ, + self.headers.get("etag"), + None, + self.headers.get("last-modified"), ): - if parse_etags(environ.get('HTTP_IF_MATCH')): + if parse_etags(environ.get("HTTP_IF_MATCH")): self.status_code = 412 else: self.status_code = 304 - if self.automatically_set_content_length and 'content-length' not in self.headers: + if ( + self.automatically_set_content_length + and "content-length" not in self.headers + ): length = self.calculate_content_length() if length is not None: - self.headers['Content-Length'] = length + self.headers["Content-Length"] = length return self def add_etag(self, overwrite=False, weak=False): """Add an etag for the current response if there is none yet.""" - if overwrite or 'etag' not in self.headers: + if overwrite or "etag" not in self.headers: self.set_etag(generate_etag(self.get_data()), weak) def set_etag(self, etag, weak=False): """Set the etag, and override the old one if there was one.""" - self.headers['ETag'] = quote_etag(etag, weak) + self.headers["ETag"] = quote_etag(etag, weak) def get_etag(self): """Return a tuple in the form ``(etag, is_weak)``. If there is no ETag the return value is ``(None, None)``. """ - return unquote_etag(self.headers.get('ETag')) + return unquote_etag(self.headers.get("ETag")) def freeze(self, no_etag=False): """Call this method if you want to make your response object ready for @@ -246,22 +257,25 @@ class ETagResponseMixin(object): self.add_etag() super(ETagResponseMixin, self).freeze() - accept_ranges = header_property('Accept-Ranges', doc=''' - The `Accept-Ranges` header. Even though the name would indicate - that multiple values are supported, it must be one string token only. + accept_ranges = header_property( + "Accept-Ranges", + doc="""The `Accept-Ranges` header. Even though the name would + indicate that multiple values are supported, it must be one + string token only. The values ``'bytes'`` and ``'none'`` are common. - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) def _get_content_range(self): def on_update(rng): if not rng: - del self.headers['content-range'] + del self.headers["content-range"] else: - self.headers['Content-Range'] = rng.to_header() - rv = parse_content_range_header(self.headers.get('content-range'), - on_update) + self.headers["Content-Range"] = rng.to_header() + + rv = parse_content_range_header(self.headers.get("content-range"), on_update) # always provide a content range object to make the descriptor # more user friendly. It provides an unset() method that can be # used to remove the header quickly. @@ -271,16 +285,20 @@ class ETagResponseMixin(object): def _set_content_range(self, value): if not value: - del self.headers['content-range'] + del self.headers["content-range"] elif isinstance(value, string_types): - self.headers['Content-Range'] = value + self.headers["Content-Range"] = value else: - self.headers['Content-Range'] = value.to_header() - content_range = property(_get_content_range, _set_content_range, doc=''' - The `Content-Range` header as - :class:`~werkzeug.datastructures.ContentRange` object. Even if the - header is not set it wil provide such an object for easier + self.headers["Content-Range"] = value.to_header() + + content_range = property( + _get_content_range, + _set_content_range, + doc="""The ``Content-Range`` header as + :class:`~werkzeug.datastructures.ContentRange` object. Even if + the header is not set it wil provide such an object for easier manipulation. - .. versionadded:: 0.7''') + .. versionadded:: 0.7""", + ) del _get_content_range, _set_content_range diff --git a/src/werkzeug/wrappers/json.py b/src/werkzeug/wrappers/json.py index 851806dd..6d5dc33d 100644 --- a/src/werkzeug/wrappers/json.py +++ b/src/werkzeug/wrappers/json.py @@ -75,9 +75,9 @@ class JSONMixin(object): """ mt = self.mimetype return ( - mt == 'application/json' - or mt.startswith('application/') - and mt.endswith('+json') + mt == "application/json" + or mt.startswith("application/") + and mt.endswith("+json") ) def _get_data_for_json(self, cache): @@ -142,4 +142,4 @@ class JSONMixin(object): for :meth:`get_json`. The default implementation raises :exc:`~werkzeug.exceptions.BadRequest`. """ - raise BadRequest('Failed to decode JSON object: {0}'.format(e)) + raise BadRequest("Failed to decode JSON object: {0}".format(e)) diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 5f1b00c8..d1c71b64 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -6,9 +6,14 @@ from .etag import ETagRequestMixin from .user_agent import UserAgentMixin -class Request(BaseRequest, AcceptMixin, ETagRequestMixin, - UserAgentMixin, AuthorizationMixin, - CommonRequestDescriptorsMixin): +class Request( + BaseRequest, + AcceptMixin, + ETagRequestMixin, + UserAgentMixin, + AuthorizationMixin, + CommonRequestDescriptorsMixin, +): """Full featured request object implementing the following mixins: - :class:`AcceptMixin` for accept header parsing diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index 8b7dc7ba..cd86cacd 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -11,7 +11,7 @@ class ResponseStream(object): iterable of the response object. """ - mode = 'wb+' + mode = "wb+" def __init__(self, response): self.response = response @@ -19,10 +19,10 @@ class ResponseStream(object): def write(self, value): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") self.response._ensure_sequence(mutable=True) self.response.response.append(value) - self.response.headers.pop('Content-Length', None) + self.response.headers.pop("Content-Length", None) return len(value) def writelines(self, seq): @@ -34,11 +34,11 @@ class ResponseStream(object): def flush(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") def isatty(self): if self.closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return False def tell(self): @@ -62,9 +62,13 @@ class ResponseStreamMixin(object): return ResponseStream(self) -class Response(BaseResponse, ETagResponseMixin, ResponseStreamMixin, - CommonResponseDescriptorsMixin, - WWWAuthenticateMixin): +class Response( + BaseResponse, + ETagResponseMixin, + ResponseStreamMixin, + CommonResponseDescriptorsMixin, + WWWAuthenticateMixin, +): """Full featured response object implementing the following mixins: - :class:`ETagResponseMixin` for etag and cache control handling diff --git a/src/werkzeug/wrappers/user_agent.py b/src/werkzeug/wrappers/user_agent.py index be106efe..72588dd9 100644 --- a/src/werkzeug/wrappers/user_agent.py +++ b/src/werkzeug/wrappers/user_agent.py @@ -2,13 +2,14 @@ from ..utils import cached_property class UserAgentMixin(object): - """Adds a `user_agent` attribute to the request object which contains the - parsed user agent of the browser that triggered the request as a - :class:`~werkzeug.useragents.UserAgent` object. + """Adds a `user_agent` attribute to the request object which + contains the parsed user agent of the browser that triggered the + request as a :class:`~werkzeug.useragents.UserAgent` object. """ @cached_property def user_agent(self): """The current user agent.""" - from werkzeug.useragents import UserAgent + from ..useragents import UserAgent + return UserAgent(self.environ) diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 60966503..f069f2d8 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -11,14 +11,24 @@ import io import re import warnings -from functools import partial, update_wrapper +from functools import partial +from functools import update_wrapper from itertools import chain -from werkzeug._compat import BytesIO, implements_iterator, \ - make_literal_wrapper, string_types, text_type, to_bytes, to_unicode, \ - try_coerce_native, wsgi_get_bytes -from werkzeug._internal import _encode_idna -from werkzeug.urls import uri_to_iri, url_join, url_parse, url_quote +from ._compat import BytesIO +from ._compat import implements_iterator +from ._compat import make_literal_wrapper +from ._compat import string_types +from ._compat import text_type +from ._compat import to_bytes +from ._compat import to_unicode +from ._compat import try_coerce_native +from ._compat import wsgi_get_bytes +from ._internal import _encode_idna +from .urls import uri_to_iri +from .urls import url_join +from .urls import url_parse +from .urls import url_quote def responder(f): @@ -34,8 +44,13 @@ def responder(f): return update_wrapper(lambda *a: f(*a)(*a[-2:]), f) -def get_current_url(environ, root_only=False, strip_querystring=False, - host_only=False, trusted_hosts=None): +def get_current_url( + environ, + root_only=False, + strip_querystring=False, + host_only=False, + trusted_hosts=None, +): """A handy helper function that recreates the full URL as IRI for the current request or parts of it. Here's an example: @@ -70,19 +85,19 @@ def get_current_url(environ, root_only=False, strip_querystring=False, :param trusted_hosts: a list of trusted hosts, see :func:`host_is_trusted` for more information. """ - tmp = [environ['wsgi.url_scheme'], '://', get_host(environ, trusted_hosts)] + tmp = [environ["wsgi.url_scheme"], "://", get_host(environ, trusted_hosts)] cat = tmp.append if host_only: - return uri_to_iri(''.join(tmp) + '/') - cat(url_quote(wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))).rstrip('/')) - cat('/') + return uri_to_iri("".join(tmp) + "/") + cat(url_quote(wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))).rstrip("/")) + cat("/") if not root_only: - cat(url_quote(wsgi_get_bytes(environ.get('PATH_INFO', '')).lstrip(b'/'))) + cat(url_quote(wsgi_get_bytes(environ.get("PATH_INFO", "")).lstrip(b"/"))) if not strip_querystring: qs = get_query_string(environ) if qs: - cat('?' + qs) - return uri_to_iri(''.join(tmp)) + cat("?" + qs) + return uri_to_iri("".join(tmp)) def host_is_trusted(hostname, trusted_list): @@ -103,8 +118,8 @@ def host_is_trusted(hostname, trusted_list): trusted_list = [trusted_list] def _normalize(hostname): - if ':' in hostname: - hostname = hostname.rsplit(':', 1)[0] + if ":" in hostname: + hostname = hostname.rsplit(":", 1)[0] return _encode_idna(hostname) try: @@ -112,7 +127,7 @@ def host_is_trusted(hostname, trusted_list): except UnicodeError: return False for ref in trusted_list: - if ref.startswith('.'): + if ref.startswith("."): ref = ref[1:] suffix_match = True else: @@ -123,7 +138,7 @@ def host_is_trusted(hostname, trusted_list): return False if ref == hostname: return True - if suffix_match and hostname.endswith(b'.' + ref): + if suffix_match and hostname.endswith(b"." + ref): return True return False @@ -144,20 +159,23 @@ def get_host(environ, trusted_hosts=None): :raise ~werkzeug.exceptions.SecurityError: If the host is not trusted. """ - if 'HTTP_HOST' in environ: - rv = environ['HTTP_HOST'] - if environ['wsgi.url_scheme'] == 'http' and rv.endswith(':80'): + if "HTTP_HOST" in environ: + rv = environ["HTTP_HOST"] + if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"): rv = rv[:-3] - elif environ['wsgi.url_scheme'] == 'https' and rv.endswith(':443'): + elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"): rv = rv[:-4] else: - rv = environ['SERVER_NAME'] - if (environ['wsgi.url_scheme'], environ['SERVER_PORT']) not \ - in (('https', '443'), ('http', '80')): - rv += ':' + environ['SERVER_PORT'] + rv = environ["SERVER_NAME"] + if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in ( + ("https", "443"), + ("http", "80"), + ): + rv += ":" + environ["SERVER_PORT"] if trusted_hosts is not None: if not host_is_trusted(rv, trusted_hosts): - from werkzeug.exceptions import SecurityError + from .exceptions import SecurityError + raise SecurityError('Host "%s" is not trusted' % rv) return rv @@ -171,10 +189,10 @@ def get_content_length(environ): :param environ: the WSGI environ to fetch the content length from. """ - if environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked': + if environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked": return None - content_length = environ.get('CONTENT_LENGTH') + content_length = environ.get("CONTENT_LENGTH") if content_length is not None: try: return max(0, int(content_length)) @@ -199,13 +217,13 @@ def get_input_stream(environ, safe_fallback=True): content length is not set. Disabling this allows infinite streams, which can be a denial-of-service risk. """ - stream = environ['wsgi.input'] + stream = environ["wsgi.input"] content_length = get_content_length(environ) # A wsgi extension that tells us if the input is terminated. In # that case we return the stream unchanged as we know we can safely # read it until the end. - if environ.get('wsgi.input_terminated'): + if environ.get("wsgi.input_terminated"): return stream # If the request doesn't specify a content length, returning the stream is @@ -228,14 +246,14 @@ def get_query_string(environ): :param environ: the WSGI environment object to get the query string from. """ - qs = wsgi_get_bytes(environ.get('QUERY_STRING', '')) + qs = wsgi_get_bytes(environ.get("QUERY_STRING", "")) # QUERY_STRING really should be ascii safe but some browsers # will send us some unicode stuff (I am looking at you IE). # In that case we want to urllib quote it badly. - return try_coerce_native(url_quote(qs, safe=':&%=+$!*\'(),')) + return try_coerce_native(url_quote(qs, safe=":&%=+$!*'(),")) -def get_path_info(environ, charset='utf-8', errors='replace'): +def get_path_info(environ, charset="utf-8", errors="replace"): """Returns the `PATH_INFO` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -248,11 +266,11 @@ def get_path_info(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('PATH_INFO', '')) + path = wsgi_get_bytes(environ.get("PATH_INFO", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def get_script_name(environ, charset='utf-8', errors='replace'): +def get_script_name(environ, charset="utf-8", errors="replace"): """Returns the `SCRIPT_NAME` from the WSGI environment and properly decodes it. This also takes care about the WSGI decoding dance on Python 3 environments. if the `charset` is set to `None` a @@ -265,11 +283,11 @@ def get_script_name(environ, charset='utf-8', errors='replace'): decoding should be performed. :param errors: the decoding error handling. """ - path = wsgi_get_bytes(environ.get('SCRIPT_NAME', '')) + path = wsgi_get_bytes(environ.get("SCRIPT_NAME", "")) return to_unicode(path, charset, errors, allow_none_charset=True) -def pop_path_info(environ, charset='utf-8', errors='replace'): +def pop_path_info(environ, charset="utf-8", errors="replace"): """Removes and returns the next segment of `PATH_INFO`, pushing it onto `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. @@ -296,32 +314,32 @@ def pop_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is modified. """ - path = environ.get('PATH_INFO') + path = environ.get("PATH_INFO") if not path: return None - script_name = environ.get('SCRIPT_NAME', '') + script_name = environ.get("SCRIPT_NAME", "") # shift multiple leading slashes over old_path = path - path = path.lstrip('/') + path = path.lstrip("/") if path != old_path: - script_name += '/' * (len(old_path) - len(path)) + script_name += "/" * (len(old_path) - len(path)) - if '/' not in path: - environ['PATH_INFO'] = '' - environ['SCRIPT_NAME'] = script_name + path + if "/" not in path: + environ["PATH_INFO"] = "" + environ["SCRIPT_NAME"] = script_name + path rv = wsgi_get_bytes(path) else: - segment, path = path.split('/', 1) - environ['PATH_INFO'] = '/' + path - environ['SCRIPT_NAME'] = script_name + segment + segment, path = path.split("/", 1) + environ["PATH_INFO"] = "/" + path + environ["SCRIPT_NAME"] = script_name + segment rv = wsgi_get_bytes(segment) return to_unicode(rv, charset, errors, allow_none_charset=True) -def peek_path_info(environ, charset='utf-8', errors='replace'): +def peek_path_info(environ, charset="utf-8", errors="replace"): """Returns the next segment on the `PATH_INFO` or `None` if there is none. Works like :func:`pop_path_info` without modifying the environment: @@ -342,18 +360,19 @@ def peek_path_info(environ, charset='utf-8', errors='replace'): :param environ: the WSGI environment that is checked. """ - segments = environ.get('PATH_INFO', '').lstrip('/').split('/', 1) + segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) if segments: - return to_unicode(wsgi_get_bytes(segments[0]), - charset, errors, allow_none_charset=True) + return to_unicode( + wsgi_get_bytes(segments[0]), charset, errors, allow_none_charset=True + ) def extract_path_info( environ_or_baseurl, path_or_url, - charset='utf-8', - errors='werkzeug.url_quote', - collapse_http_schemes=True + charset="utf-8", + errors="werkzeug.url_quote", + collapse_http_schemes=True, ): """Extracts the path info from the given URL (or WSGI environment) and path. The path info returned is a unicode string, not a bytestring @@ -395,29 +414,29 @@ def extract_path_info( .. versionadded:: 0.6 """ + def _normalize_netloc(scheme, netloc): - parts = netloc.split(u'@', 1)[-1].split(u':', 1) + parts = netloc.split(u"@", 1)[-1].split(u":", 1) if len(parts) == 2: netloc, port = parts - if (scheme == u'http' and port == u'80') or \ - (scheme == u'https' and port == u'443'): + if (scheme == u"http" and port == u"80") or ( + scheme == u"https" and port == u"443" + ): port = None else: netloc = parts[0] port = None if port is not None: - netloc += u':' + port + netloc += u":" + port return netloc # make sure whatever we are working on is a IRI and parse it path = uri_to_iri(path_or_url, charset, errors) if isinstance(environ_or_baseurl, dict): - environ_or_baseurl = get_current_url(environ_or_baseurl, - root_only=True) + environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) base_iri = uri_to_iri(environ_or_baseurl, charset, errors) base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path, = \ - url_parse(url_join(base_iri, path))[:3] + cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3] # normalize the network location base_netloc = _normalize_netloc(base_scheme, base_netloc) @@ -426,11 +445,10 @@ def extract_path_info( # is that IRI even on a known HTTP scheme? if collapse_http_schemes: for scheme in base_scheme, cur_scheme: - if scheme not in (u'http', u'https'): + if scheme not in (u"http", u"https"): return None else: - if not (base_scheme in (u'http', u'https') - and base_scheme == cur_scheme): + if not (base_scheme in (u"http", u"https") and base_scheme == cur_scheme): return None # are the netlocs compatible? @@ -438,16 +456,15 @@ def extract_path_info( return None # are we below the application path? - base_path = base_path.rstrip(u'/') + base_path = base_path.rstrip(u"/") if not cur_path.startswith(base_path): return None - return u'/' + cur_path[len(base_path):].lstrip(u'/') + return u"/" + cur_path[len(base_path) :].lstrip(u"/") @implements_iterator class ClosingIterator(object): - """The WSGI specification requires that all middlewares and gateways respect the `close` callback of the iterable returned by the application. Because it is useful to add another close action to a returned iterable @@ -478,7 +495,7 @@ class ClosingIterator(object): callbacks = [callbacks] else: callbacks = list(callbacks) - iterable_close = getattr(iterable, 'close', None) + iterable_close = getattr(iterable, "close", None) if iterable_close: callbacks.insert(0, iterable_close) self._callbacks = callbacks @@ -510,12 +527,11 @@ def wrap_file(environ, file, buffer_size=8192): :param file: a :class:`file`-like object with a :meth:`~file.read` method. :param buffer_size: number of bytes for one iteration. """ - return environ.get('wsgi.file_wrapper', FileWrapper)(file, buffer_size) + return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size) @implements_iterator class FileWrapper(object): - """This class can be used to convert a :class:`file`-like object into an iterable. It yields `buffer_size` blocks until the file is fully read. @@ -538,22 +554,22 @@ class FileWrapper(object): self.buffer_size = buffer_size def close(self): - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() def seekable(self): - if hasattr(self.file, 'seekable'): + if hasattr(self.file, "seekable"): return self.file.seekable() - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): return True return False def seek(self, *args): - if hasattr(self.file, 'seek'): + if hasattr(self.file, "seek"): self.file.seek(*args) def tell(self): - if hasattr(self.file, 'tell'): + if hasattr(self.file, "tell"): return self.file.tell() return None @@ -593,7 +609,7 @@ class _RangeWrapper(object): if byte_range is not None: self.end_byte = self.start_byte + self.byte_range self.read_length = 0 - self.seekable = hasattr(iterable, 'seekable') and iterable.seekable() + self.seekable = hasattr(iterable, "seekable") and iterable.seekable() self.end_reached = False def __iter__(self): @@ -618,7 +634,7 @@ class _RangeWrapper(object): while self.read_length <= self.start_byte: chunk = self._next_chunk() if chunk is not None: - chunk = chunk[self.start_byte - self.read_length:] + chunk = chunk[self.start_byte - self.read_length :] contextual_read_length = self.start_byte return chunk, contextual_read_length @@ -633,7 +649,7 @@ class _RangeWrapper(object): chunk = self._next_chunk() if self.end_byte is not None and self.read_length >= self.end_byte: self.end_reached = True - return chunk[:self.end_byte - contextual_read_length] + return chunk[: self.end_byte - contextual_read_length] return chunk def __next__(self): @@ -644,16 +660,17 @@ class _RangeWrapper(object): raise StopIteration() def close(self): - if hasattr(self.iterable, 'close'): + if hasattr(self.iterable, "close"): self.iterable.close() def _make_chunk_iter(stream, limit, buffer_size): """Helper for the line and chunk iter functions.""" if isinstance(stream, (bytes, bytearray, text_type)): - raise TypeError('Passed a string or byte object instead of ' - 'true iterator or stream.') - if not hasattr(stream, 'read'): + raise TypeError( + "Passed a string or byte object instead of true iterator or stream." + ) + if not hasattr(stream, "read"): for item in stream: if item: yield item @@ -668,8 +685,7 @@ def _make_chunk_iter(stream, limit, buffer_size): yield item -def make_line_iter(stream, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_line_iter(stream, limit=None, buffer_size=10 * 1024, cap_at_buffer=False): """Safely iterates line-based over an input stream. If the input stream is not a :class:`LimitedStream` the `limit` parameter is mandatory. @@ -703,15 +719,15 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return s = make_literal_wrapper(first_item) - empty = s('') - cr = s('\r') - lf = s('\n') - crlf = s('\r\n') + empty = s("") + cr = s("\r") + lf = s("\n") + crlf = s("\r\n") _iter = chain((first_item,), _iter) @@ -719,7 +735,7 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, _join = empty.join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break new_buf = [] @@ -754,8 +770,9 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024, yield previous -def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, - cap_at_buffer=False): +def make_chunk_iter( + stream, separator, limit=None, buffer_size=10 * 1024, cap_at_buffer=False +): """Works like :func:`make_line_iter` but accepts a separator which divides chunks. If you want newline based processing you should use :func:`make_line_iter` instead as it @@ -782,23 +799,23 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, '') + first_item = next(_iter, "") if not first_item: return _iter = chain((first_item,), _iter) if isinstance(first_item, text_type): separator = to_unicode(separator) - _split = re.compile(r'(%s)' % re.escape(separator)).split - _join = u''.join + _split = re.compile(r"(%s)" % re.escape(separator)).split + _join = u"".join else: separator = to_bytes(separator) - _split = re.compile(b'(' + re.escape(separator) + b')').split - _join = b''.join + _split = re.compile(b"(" + re.escape(separator) + b")").split + _join = b"".join buffer = [] while 1: - new_data = next(_iter, '') + new_data = next(_iter, "") if not new_data: break chunks = _split(new_data) @@ -828,7 +845,6 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024, @implements_iterator class LimitedStream(io.IOBase): - """Wraps a stream so that it doesn't read more than n bytes. If the stream is exhausted and the caller tries to get more bytes from it :func:`on_exhausted` is called which by default returns an empty @@ -891,7 +907,8 @@ class LimitedStream(io.IOBase): the client went away. By default a :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised. """ - from werkzeug.exceptions import ClientDisconnected + from .exceptions import ClientDisconnected + raise ClientDisconnected() def exhaust(self, chunk_size=1024 * 64): @@ -984,9 +1001,10 @@ class LimitedStream(io.IOBase): return True -from werkzeug.middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware -from werkzeug.middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware -from werkzeug.middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware +# DEPRECATED +from .middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware +from .middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware +from .middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware class ProxyMiddleware(_ProxyMiddleware): diff --git a/tests/conftest.py b/tests/conftest.py index 0a76d337..1ce4fd53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,17 +8,15 @@ """ from __future__ import print_function -from itertools import count - import logging import os import platform import signal -import sys - import subprocess +import sys import textwrap import time +from itertools import count import pytest @@ -28,11 +26,12 @@ from werkzeug.urls import url_quote from werkzeug.utils import cached_property try: - __import__('pytest_xprocess') + __import__("pytest_xprocess") except ImportError: - @pytest.fixture(scope='session') + + @pytest.fixture(scope="session") def xprocess(): - pytest.skip('pytest-xprocess not installed.') + pytest.skip("pytest-xprocess not installed.") port_generator = count(13220) @@ -40,7 +39,7 @@ port_generator = count(13220) def _patch_reloader_loop(): def f(x): - print('reloader loop finished') + print("reloader loop finished") # Need to flush for some reason even though xprocess opens the # subprocess' stdout in unbuffered mode. # flush=True makes the test fail on py2, so flush manually @@ -48,6 +47,7 @@ def _patch_reloader_loop(): return time.sleep(x) import werkzeug._reloader + werkzeug._reloader.ReloaderLoop._sleep = staticmethod(f) @@ -59,11 +59,12 @@ pid_logger.addHandler(pid_handler) def _get_pid_middleware(f): def inner(environ, start_response): - if environ['PATH_INFO'] == '/_getpid': - start_response('200 OK', [('Content-Type', 'text/plain')]) + if environ["PATH_INFO"] == "/_getpid": + start_response("200 OK", [("Content-Type", "text/plain")]) pid_logger.info("pid=%s", os.getpid()) return [to_bytes(str(os.getpid()))] return f(environ, start_response) + return inner @@ -71,6 +72,7 @@ def _dev_server(): _patch_reloader_loop() sys.path.insert(0, sys.argv[1]) import testsuite_app + app = _get_pid_middleware(testsuite_app.app) serving.run_simple(application=app, **testsuite_app.kwargs) @@ -90,10 +92,10 @@ class _ServerInfo(object): @cached_property def logfile(self): - return self.xprocess.getinfo('dev_server').logpath.open() + return self.xprocess.getinfo("dev_server").logpath.open() def request_pid(self): - if self.url.startswith('http+unix://'): + if self.url.startswith("http+unix://"): from requests_unixsocket import get as rget else: from requests import get as rget @@ -101,7 +103,7 @@ class _ServerInfo(object): for i in range(10): time.sleep(0.1 * i) try: - response = rget(self.url + '/_getpid', verify=False) + response = rget(self.url + "/_getpid", verify=False) self.last_pid = int(response.text) return self.last_pid except Exception as e: # urllib also raises socketerrors @@ -114,51 +116,58 @@ class _ServerInfo(object): time.sleep(0.1 * i) new_pid = self.request_pid() if not new_pid: - raise RuntimeError('Server is down.') + raise RuntimeError("Server is down.") if new_pid != old_pid: return - raise RuntimeError('Server did not reload.') + raise RuntimeError("Server did not reload.") def wait_for_reloader_loop(self): for i in range(20): time.sleep(0.1 * i) line = self.logfile.readline() - if 'reloader loop finished' in line: + if "reloader loop finished" in line: return @pytest.fixture def dev_server(tmpdir, xprocess, request, monkeypatch): - '''Run werkzeug.serving.run_simple in its own process. + """Run werkzeug.serving.run_simple in its own process. :param application: String for the module that will be created. The module must have a global ``app`` object, a ``kwargs`` dict is also available whose values will be passed to ``run_simple``. - ''' + """ + def run_dev_server(application): - app_pkg = tmpdir.mkdir('testsuite_app') - appfile = app_pkg.join('__init__.py') + app_pkg = tmpdir.mkdir("testsuite_app") + appfile = app_pkg.join("__init__.py") port = next(port_generator) - appfile.write('\n\n'.join(( - "kwargs = {{'hostname': 'localhost', 'port': {port:d}}}".format( - port=port), - textwrap.dedent(application) - ))) - - monkeypatch.delitem(sys.modules, 'testsuite_app', raising=False) + appfile.write( + "\n\n".join( + ( + "kwargs = {{'hostname': 'localhost', 'port': {port:d}}}".format( + port=port + ), + textwrap.dedent(application), + ) + ) + ) + + monkeypatch.delitem(sys.modules, "testsuite_app", raising=False) monkeypatch.syspath_prepend(str(tmpdir)) import testsuite_app - hostname = testsuite_app.kwargs['hostname'] - port = testsuite_app.kwargs['port'] - addr = '{}:{}'.format(hostname, port) - - if hostname.startswith('unix://'): - addr = hostname.split('unix://', 1)[1] - requests_url = 'http+unix://' + url_quote(addr, safe='') - elif testsuite_app.kwargs.get('ssl_context', None): - requests_url = 'https://localhost:{0}'.format(port) + + hostname = testsuite_app.kwargs["hostname"] + port = testsuite_app.kwargs["port"] + addr = "{}:{}".format(hostname, port) + + if hostname.startswith("unix://"): + addr = hostname.split("unix://", 1)[1] + requests_url = "http+unix://" + url_quote(addr, safe="") + elif testsuite_app.kwargs.get("ssl_context", None): + requests_url = "https://localhost:{0}".format(port) else: - requests_url = 'http://localhost:{0}'.format(port) + requests_url = "http://localhost:{0}".format(port) info = _ServerInfo(xprocess, addr, requests_url, port) @@ -171,7 +180,7 @@ def dev_server(tmpdir, xprocess, request, monkeypatch): def pattern(self): return "pid=%s" % info.request_pid() - xprocess.ensure('dev_server', Starter, restart=True) + xprocess.ensure("dev_server", Starter, restart=True) @request.addfinalizer def teardown(): @@ -191,5 +200,5 @@ def dev_server(tmpdir, xprocess, request, monkeypatch): return run_dev_server -if __name__ == '__main__': +if __name__ == "__main__": _dev_server() diff --git a/tests/contrib/cache/conftest.py b/tests/contrib/cache/conftest.py index cb651098..655f0fa6 100644 --- a/tests/contrib/cache/conftest.py +++ b/tests/contrib/cache/conftest.py @@ -4,10 +4,7 @@ import pytest # build the path to the uwsgi marker file # when running in tox, this will be relative to the tox env -filename = os.path.join( - os.environ.get('TOX_ENVTMPDIR', ''), - 'test_uwsgi_failed' -) +filename = os.path.join(os.environ.get("TOX_ENVTMPDIR", ""), "test_uwsgi_failed") @pytest.hookimpl(tryfirst=True, hookwrapper=True) @@ -20,9 +17,9 @@ def pytest_runtest_makereport(item, call): outcome = yield report = outcome.get_result() - if item.cls.__name__ != 'TestUWSGICache': + if item.cls.__name__ != "TestUWSGICache": return if report.failed: - with open(filename, 'a') as f: - f.write(item.name + '\n') + with open(filename, "a") as f: + f.write(item.name + "\n") diff --git a/tests/contrib/cache/test_cache.py b/tests/contrib/cache/test_cache.py index 0fcc5e46..c13227e3 100644 --- a/tests/contrib/cache/test_cache.py +++ b/tests/contrib/cache/test_cache.py @@ -12,8 +12,6 @@ import errno import pytest -pytestmark = pytest.mark.skip('werkzeug.contrib.cache moved to cachelib') - from werkzeug._compat import text_type from werkzeug.contrib import cache @@ -33,6 +31,8 @@ except ImportError: except ImportError: memcache = None +pytestmark = pytest.mark.skip("werkzeug.contrib.cache moved to cachelib") + class CacheTestsBase(object): _can_use_fast_sleep = True @@ -41,13 +41,15 @@ class CacheTestsBase(object): @pytest.fixture def fast_sleep(self, monkeypatch): if self._can_use_fast_sleep: + def sleep(delta): orig_time = cache.time - monkeypatch.setattr(cache, 'time', lambda: orig_time() + delta) + monkeypatch.setattr(cache, "time", lambda: orig_time() + delta) return sleep else: import time + return time.sleep @pytest.fixture @@ -63,13 +65,13 @@ class CacheTestsBase(object): class GenericCacheTests(CacheTestsBase): def test_generic_get_dict(self, c): - assert c.set('a', 'a') - assert c.set('b', 'b') - d = c.get_dict('a', 'b') - assert 'a' in d - assert 'a' == d['a'] - assert 'b' in d - assert 'b' == d['b'] + assert c.set("a", "a") + assert c.set("b", "b") + d = c.get_dict("a", "b") + assert "a" in d + assert "a" == d["a"] + assert "b" in d + assert "b" == d["b"] def test_generic_set_get(self, c): for i in range(3): @@ -80,71 +82,71 @@ class GenericCacheTests(CacheTestsBase): assert result == i * i, result def test_generic_get_set(self, c): - assert c.set('foo', ['bar']) - assert c.get('foo') == ['bar'] + assert c.set("foo", ["bar"]) + assert c.get("foo") == ["bar"] def test_generic_get_many(self, c): - assert c.set('foo', ['bar']) - assert c.set('spam', 'eggs') - assert c.get_many('foo', 'spam') == [['bar'], 'eggs'] + assert c.set("foo", ["bar"]) + assert c.set("spam", "eggs") + assert c.get_many("foo", "spam") == [["bar"], "eggs"] def test_generic_set_many(self, c): - assert c.set_many({'foo': 'bar', 'spam': ['eggs']}) - assert c.get('foo') == 'bar' - assert c.get('spam') == ['eggs'] + assert c.set_many({"foo": "bar", "spam": ["eggs"]}) + assert c.get("foo") == "bar" + assert c.get("spam") == ["eggs"] def test_generic_add(self, c): # sanity check that add() works like set() - assert c.add('foo', 'bar') - assert c.get('foo') == 'bar' - assert not c.add('foo', 'qux') - assert c.get('foo') == 'bar' + assert c.add("foo", "bar") + assert c.get("foo") == "bar" + assert not c.add("foo", "qux") + assert c.get("foo") == "bar" def test_generic_delete(self, c): - assert c.add('foo', 'bar') - assert c.get('foo') == 'bar' - assert c.delete('foo') - assert c.get('foo') is None + assert c.add("foo", "bar") + assert c.get("foo") == "bar" + assert c.delete("foo") + assert c.get("foo") is None def test_generic_delete_many(self, c): - assert c.add('foo', 'bar') - assert c.add('spam', 'eggs') - assert c.delete_many('foo', 'spam') - assert c.get('foo') is None - assert c.get('spam') is None + assert c.add("foo", "bar") + assert c.add("spam", "eggs") + assert c.delete_many("foo", "spam") + assert c.get("foo") is None + assert c.get("spam") is None def test_generic_inc_dec(self, c): - assert c.set('foo', 1) - assert c.inc('foo') == c.get('foo') == 2 - assert c.dec('foo') == c.get('foo') == 1 - assert c.delete('foo') + assert c.set("foo", 1) + assert c.inc("foo") == c.get("foo") == 2 + assert c.dec("foo") == c.get("foo") == 1 + assert c.delete("foo") def test_generic_true_false(self, c): - assert c.set('foo', True) - assert c.get('foo') in (True, 1) - assert c.set('bar', False) - assert c.get('bar') in (False, 0) + assert c.set("foo", True) + assert c.get("foo") in (True, 1) + assert c.set("bar", False) + assert c.get("bar") in (False, 0) def test_generic_timeout(self, c, fast_sleep): - c.set('foo', 'bar', 0) - assert c.get('foo') == 'bar' - c.set('baz', 'qux', 1) - assert c.get('baz') == 'qux' + c.set("foo", "bar", 0) + assert c.get("foo") == "bar" + c.set("baz", "qux", 1) + assert c.get("baz") == "qux" fast_sleep(3) # timeout of zero means no timeout - assert c.get('foo') == 'bar' + assert c.get("foo") == "bar" if self._guaranteed_deletes: - assert c.get('baz') is None + assert c.get("baz") is None def test_generic_has(self, c): - assert c.has('foo') in (False, 0) - assert c.has('spam') in (False, 0) - assert c.set('foo', 'bar') - assert c.has('foo') in (True, 1) - assert c.has('spam') in (False, 0) - c.delete('foo') - assert c.has('foo') in (False, 0) - assert c.has('spam') in (False, 0) + assert c.has("foo") in (False, 0) + assert c.has("spam") in (False, 0) + assert c.set("foo", "bar") + assert c.has("foo") in (True, 1) + assert c.has("spam") in (False, 0) + c.delete("foo") + assert c.has("foo") in (False, 0) + assert c.has("spam") in (False, 0) class TestSimpleCache(GenericCacheTests): @@ -154,10 +156,10 @@ class TestSimpleCache(GenericCacheTests): def test_purge(self): c = cache.SimpleCache(threshold=2) - c.set('a', 'a') - c.set('b', 'b') - c.set('c', 'c') - c.set('d', 'd') + c.set("a", "a") + c.set("b", "b") + c.set("c", "c") + c.set("d", "d") # Cache purges old items *before* it sets new ones. assert len(c._cache) == 3 @@ -178,7 +180,7 @@ class TestFileSystemCache(GenericCacheTests): assert nof_cache_files <= THRESHOLD def test_filesystemcache_clear(self, c): - assert c.set('foo', 'bar') + assert c.set("foo", "bar") nof_cache_files = c.get(c._fs_count_file) assert nof_cache_files == 1 assert c.clear() @@ -202,13 +204,13 @@ class TestFileSystemCache(GenericCacheTests): assert nof_cache_files is None def test_count_file_accuracy(self, c): - assert c.set('foo', 'bar') - assert c.set('moo', 'car') - c.add('moo', 'tar') + assert c.set("foo", "bar") + assert c.set("moo", "car") + c.add("moo", "tar") assert c.get(c._fs_count_file) == 2 - assert c.add('too', 'far') + assert c.add("too", "far") assert c.get(c._fs_count_file) == 3 - assert c.delete('moo') + assert c.delete("moo") assert c.get(c._fs_count_file) == 2 assert c.clear() assert c.get(c._fs_count_file) == 0 @@ -220,124 +222,121 @@ class TestFileSystemCache(GenericCacheTests): class TestRedisCache(GenericCacheTests): _can_use_fast_sleep = False - @pytest.fixture(scope='class', autouse=True) + @pytest.fixture(scope="class", autouse=True) def requirements(self, xprocess): if redis is None: pytest.skip('Python package "redis" is not installed.') def prepare(cwd): - return '[Rr]eady to accept connections', ['redis-server'] + return "[Rr]eady to accept connections", ["redis-server"] try: - xprocess.ensure('redis_server', prepare) + xprocess.ensure("redis_server", prepare) except IOError as e: # xprocess raises FileNotFoundError if e.errno == errno.ENOENT: - pytest.skip('Redis is not installed.') + pytest.skip("Redis is not installed.") else: raise yield - xprocess.getinfo('redis_server').terminate() + xprocess.getinfo("redis_server").terminate() @pytest.fixture(params=(None, False, True)) def make_cache(self, request): if request.param is None: - host = 'localhost' + host = "localhost" elif request.param: host = redis.StrictRedis() else: host = redis.Redis() - c = cache.RedisCache( - host=host, - key_prefix='werkzeug-test-case:', - ) + c = cache.RedisCache(host=host, key_prefix="werkzeug-test-case:") yield lambda: c c.clear() def test_compat(self, c): - assert c._client.set(c.key_prefix + 'foo', 'Awesome') - assert c.get('foo') == b'Awesome' - assert c._client.set(c.key_prefix + 'foo', '42') - assert c.get('foo') == 42 + assert c._client.set(c.key_prefix + "foo", "Awesome") + assert c.get("foo") == b"Awesome" + assert c._client.set(c.key_prefix + "foo", "42") + assert c.get("foo") == 42 def test_empty_host(self): with pytest.raises(ValueError) as exc_info: cache.RedisCache(host=None) - assert text_type(exc_info.value) == 'RedisCache host parameter may not be None' + assert text_type(exc_info.value) == "RedisCache host parameter may not be None" class TestMemcachedCache(GenericCacheTests): _can_use_fast_sleep = False _guaranteed_deletes = False - @pytest.fixture(scope='class', autouse=True) + @pytest.fixture(scope="class", autouse=True) def requirements(self, xprocess): if memcache is None: pytest.skip( - 'Python package for memcache is not installed. Need one of ' + "Python package for memcache is not installed. Need one of " '"pylibmc", "google.appengine", or "memcache".' ) def prepare(cwd): - return '', ['memcached'] + return "", ["memcached"] try: - xprocess.ensure('memcached', prepare) + xprocess.ensure("memcached", prepare) except IOError as e: # xprocess raises FileNotFoundError if e.errno == errno.ENOENT: - pytest.skip('Memcached is not installed.') + pytest.skip("Memcached is not installed.") else: raise yield - xprocess.getinfo('memcached').terminate() + xprocess.getinfo("memcached").terminate() @pytest.fixture def make_cache(self): - c = cache.MemcachedCache(key_prefix='werkzeug-test-case:') + c = cache.MemcachedCache(key_prefix="werkzeug-test-case:") yield lambda: c c.clear() def test_compat(self, c): - assert c._client.set(c.key_prefix + 'foo', 'bar') - assert c.get('foo') == 'bar' + assert c._client.set(c.key_prefix + "foo", "bar") + assert c.get("foo") == "bar" def test_huge_timeouts(self, c): # Timeouts greater than epoch are interpreted as POSIX timestamps # (i.e. not relative to now, but relative to epoch) epoch = 2592000 - c.set('foo', 'bar', epoch + 100) - assert c.get('foo') == 'bar' + c.set("foo", "bar", epoch + 100) + assert c.get("foo") == "bar" class TestUWSGICache(GenericCacheTests): _can_use_fast_sleep = False _guaranteed_deletes = False - @pytest.fixture(scope='class', autouse=True) + @pytest.fixture(scope="class", autouse=True) def requirements(self): try: import uwsgi # NOQA except ImportError: pytest.skip( 'Python "uwsgi" package is only avaialable when running ' - 'inside uWSGI.' + "inside uWSGI." ) @pytest.fixture def make_cache(self): - c = cache.UWSGICache(cache='werkzeugtest') + c = cache.UWSGICache(cache="werkzeugtest") yield lambda: c c.clear() class TestNullCache(CacheTestsBase): - @pytest.fixture(scope='class', autouse=True) + @pytest.fixture(scope="class", autouse=True) def make_cache(self): return cache.NullCache def test_has(self, c): - assert not c.has('foo') + assert not c.has("foo") diff --git a/tests/contrib/test_atom.py b/tests/contrib/test_atom.py index 4117d571..5a10556e 100644 --- a/tests/contrib/test_atom.py +++ b/tests/contrib/test_atom.py @@ -9,9 +9,12 @@ :license: BSD-3-Clause """ import datetime + import pytest -from werkzeug.contrib.atom import format_iso8601, AtomFeed, FeedEntry +from werkzeug.contrib.atom import AtomFeed +from werkzeug.contrib.atom import FeedEntry +from werkzeug.contrib.atom import format_iso8601 class TestAtomFeed(object): @@ -25,26 +28,25 @@ class TestAtomFeed(object): def test_atom_title_no_id(self): with pytest.raises(ValueError): - AtomFeed(title='test_title') + AtomFeed(title="test_title") def test_atom_add_one(self): - a = AtomFeed(title='test_title', id=1) - f = FeedEntry( - title='test_title', id=1, updated=datetime.datetime.now()) + a = AtomFeed(title="test_title", id=1) + f = FeedEntry(title="test_title", id=1, updated=datetime.datetime.now()) assert len(a.entries) == 0 a.add(f) assert len(a.entries) == 1 def test_atom_add_one_kwargs(self): - a = AtomFeed(title='test_title', id=1) + a = AtomFeed(title="test_title", id=1) assert len(a.entries) == 0 - a.add(title='test_title', id=1, updated=datetime.datetime.now()) + a.add(title="test_title", id=1, updated=datetime.datetime.now()) assert len(a.entries) == 1 assert isinstance(a.entries[0], FeedEntry) def test_atom_to_str(self): updated_time = datetime.datetime.now() - expected_repr = ''' + expected_repr = """ <?xml version="1.0" encoding="utf-8"?> <feed xmlns="http://www.w3.org/2005/Atom"> <title type="text">test_title</title> @@ -52,10 +54,11 @@ class TestAtomFeed(object): <updated>%s</updated> <generator>Werkzeug</generator> </feed> - ''' % format_iso8601(updated_time) - a = AtomFeed(title='test_title', id=1, updated=updated_time) - assert str(a).strip().replace(' ', '') == \ - expected_repr.strip().replace(' ', '') + """ % format_iso8601( + updated_time + ) + a = AtomFeed(title="test_title", id=1, updated=updated_time) + assert str(a).strip().replace(" ", "") == expected_repr.strip().replace(" ", "") class TestFeedEntry(object): @@ -69,35 +72,38 @@ class TestFeedEntry(object): def test_feed_entry_no_id(self): with pytest.raises(ValueError): - FeedEntry(title='test_title') + FeedEntry(title="test_title") def test_feed_entry_no_updated(self): with pytest.raises(ValueError): - FeedEntry(title='test_title', id=1) + FeedEntry(title="test_title", id=1) def test_feed_entry_to_str(self): updated_time = datetime.datetime.now() - expected_feed_entry_str = ''' + expected_feed_entry_str = """ <entry> <title type="text">test_title</title> <id>1</id> <updated>%s</updated> </entry> - ''' % format_iso8601(updated_time) + """ % format_iso8601( + updated_time + ) - f = FeedEntry(title='test_title', id=1, updated=updated_time) - assert str(f).strip().replace(' ', '') == \ - expected_feed_entry_str.strip().replace(' ', '') + f = FeedEntry(title="test_title", id=1, updated=updated_time) + assert str(f).strip().replace( + " ", "" + ) == expected_feed_entry_str.strip().replace(" ", "") def test_format_iso8601(): # naive datetime should be treated as utc dt = datetime.datetime(2014, 8, 31, 2, 5, 6) - assert format_iso8601(dt) == '2014-08-31T02:05:06Z' + assert format_iso8601(dt) == "2014-08-31T02:05:06Z" # tz-aware datetime dt = datetime.datetime(2014, 8, 31, 11, 5, 6, tzinfo=KST()) - assert format_iso8601(dt) == '2014-08-31T11:05:06+09:00' + assert format_iso8601(dt) == "2014-08-31T11:05:06+09:00" class KST(datetime.tzinfo): @@ -108,7 +114,7 @@ class KST(datetime.tzinfo): return datetime.timedelta(hours=9) def tzname(self, dt): - return 'KST' + return "KST" def dst(self, dt): return datetime.timedelta(0) diff --git a/tests/contrib/test_fixers.py b/tests/contrib/test_fixers.py index 79e5cc7f..2777f611 100644 --- a/tests/contrib/test_fixers.py +++ b/tests/contrib/test_fixers.py @@ -9,80 +9,78 @@ from werkzeug.wrappers import Response @Request.application def path_check_app(request): - return Response('PATH_INFO: %s\nSCRIPT_NAME: %s' % ( - request.environ.get('PATH_INFO', ''), - request.environ.get('SCRIPT_NAME', '') - )) + return Response( + "PATH_INFO: %s\nSCRIPT_NAME: %s" + % (request.environ.get("PATH_INFO", ""), request.environ.get("SCRIPT_NAME", "")) + ) class TestServerFixer(object): - def test_cgi_root_fix(self): app = fixers.CGIRootFix(path_check_app) response = Response.from_app( - app, - dict(create_environ(), - SCRIPT_NAME='/foo', - PATH_INFO='/bar')) - assert response.get_data() == b'PATH_INFO: /bar\nSCRIPT_NAME: ' + app, dict(create_environ(), SCRIPT_NAME="/foo", PATH_INFO="/bar") + ) + assert response.get_data() == b"PATH_INFO: /bar\nSCRIPT_NAME: " def test_cgi_root_fix_custom_app_root(self): - app = fixers.CGIRootFix(path_check_app, app_root='/baz/') + app = fixers.CGIRootFix(path_check_app, app_root="/baz/") response = Response.from_app( - app, - dict(create_environ(), - SCRIPT_NAME='/foo', - PATH_INFO='/bar')) - assert response.get_data() == b'PATH_INFO: /bar\nSCRIPT_NAME: baz' + app, dict(create_environ(), SCRIPT_NAME="/foo", PATH_INFO="/bar") + ) + assert response.get_data() == b"PATH_INFO: /bar\nSCRIPT_NAME: baz" def test_path_info_from_request_uri_fix(self): app = fixers.PathInfoFromRequestUriFix(path_check_app) - for key in 'REQUEST_URI', 'REQUEST_URL', 'UNENCODED_URL': - env = dict(create_environ(), SCRIPT_NAME='/test', PATH_INFO='/?????') - env[key] = '/test/foo%25bar?drop=this' + for key in "REQUEST_URI", "REQUEST_URL", "UNENCODED_URL": + env = dict(create_environ(), SCRIPT_NAME="/test", PATH_INFO="/?????") + env[key] = "/test/foo%25bar?drop=this" response = Response.from_app(app, env) - assert response.get_data() == b'PATH_INFO: /foo%bar\nSCRIPT_NAME: /test' + assert response.get_data() == b"PATH_INFO: /foo%bar\nSCRIPT_NAME: /test" def test_header_rewriter_fix(self): @Request.application def application(request): - return Response("", headers=[ - ('X-Foo', 'bar') - ]) - application = fixers.HeaderRewriterFix(application, ('X-Foo',), (('X-Bar', '42'),)) + return Response("", headers=[("X-Foo", "bar")]) + + application = fixers.HeaderRewriterFix( + application, ("X-Foo",), (("X-Bar", "42"),) + ) response = Response.from_app(application, create_environ()) - assert response.headers['Content-Type'] == 'text/plain; charset=utf-8' - assert 'X-Foo' not in response.headers - assert response.headers['X-Bar'] == '42' + assert response.headers["Content-Type"] == "text/plain; charset=utf-8" + assert "X-Foo" not in response.headers + assert response.headers["X-Bar"] == "42" class TestBrowserFixer(object): - def test_ie_fixes(self): @fixers.InternetExplorerFix @Request.application def application(request): - response = Response('binary data here', mimetype='application/vnd.ms-excel') - response.headers['Vary'] = 'Cookie' - response.headers['Content-Disposition'] = 'attachment; filename=foo.xls' + response = Response("binary data here", mimetype="application/vnd.ms-excel") + response.headers["Vary"] = "Cookie" + response.headers["Content-Disposition"] = "attachment; filename=foo.xls" return response c = Client(application, Response) - response = c.get('/', headers=[ - ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)') - ]) + response = c.get( + "/", + headers=[ + ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)") + ], + ) # IE gets no vary - assert response.get_data() == b'binary data here' - assert 'vary' not in response.headers - assert response.headers['content-disposition'] == 'attachment; filename=foo.xls' - assert response.headers['content-type'] == 'application/vnd.ms-excel' + assert response.get_data() == b"binary data here" + assert "vary" not in response.headers + assert response.headers["content-disposition"] == "attachment; filename=foo.xls" + assert response.headers["content-type"] == "application/vnd.ms-excel" # other browsers do c = Client(application, Response) - response = c.get('/') - assert response.get_data() == b'binary data here' - assert 'vary' in response.headers + response = c.get("/") + assert response.get_data() == b"binary data here" + assert "vary" in response.headers cc = ResponseCacheControl() cc.no_cache = True @@ -90,40 +88,47 @@ class TestBrowserFixer(object): @fixers.InternetExplorerFix @Request.application def application(request): - response = Response('binary data here', mimetype='application/vnd.ms-excel') - response.headers['Pragma'] = ', '.join(pragma) - response.headers['Cache-Control'] = cc.to_header() - response.headers['Content-Disposition'] = 'attachment; filename=foo.xls' + response = Response("binary data here", mimetype="application/vnd.ms-excel") + response.headers["Pragma"] = ", ".join(pragma) + response.headers["Cache-Control"] = cc.to_header() + response.headers["Content-Disposition"] = "attachment; filename=foo.xls" return response # IE has no pragma or cache control - pragma = ('no-cache',) + pragma = ("no-cache",) c = Client(application, Response) - response = c.get('/', headers=[ - ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)') - ]) - assert response.get_data() == b'binary data here' - assert 'pragma' not in response.headers - assert 'cache-control' not in response.headers - assert response.headers['content-disposition'] == 'attachment; filename=foo.xls' + response = c.get( + "/", + headers=[ + ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)") + ], + ) + assert response.get_data() == b"binary data here" + assert "pragma" not in response.headers + assert "cache-control" not in response.headers + assert response.headers["content-disposition"] == "attachment; filename=foo.xls" # IE has simplified pragma - pragma = ('no-cache', 'x-foo') + pragma = ("no-cache", "x-foo") cc.proxy_revalidate = True - response = c.get('/', headers=[ - ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)') - ]) - assert response.get_data() == b'binary data here' - assert response.headers['pragma'] == 'x-foo' - assert response.headers['cache-control'] == 'proxy-revalidate' - assert response.headers['content-disposition'] == 'attachment; filename=foo.xls' + response = c.get( + "/", + headers=[ + ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)") + ], + ) + assert response.get_data() == b"binary data here" + assert response.headers["pragma"] == "x-foo" + assert response.headers["cache-control"] == "proxy-revalidate" + assert response.headers["content-disposition"] == "attachment; filename=foo.xls" # regular browsers get everything - response = c.get('/') - assert response.get_data() == b'binary data here' - assert response.headers['pragma'] == 'no-cache, x-foo' - cc = parse_cache_control_header(response.headers['cache-control'], - cls=ResponseCacheControl) + response = c.get("/") + assert response.get_data() == b"binary data here" + assert response.headers["pragma"] == "no-cache, x-foo" + cc = parse_cache_control_header( + response.headers["cache-control"], cls=ResponseCacheControl + ) assert cc.no_cache assert cc.proxy_revalidate - assert response.headers['content-disposition'] == 'attachment; filename=foo.xls' + assert response.headers["content-disposition"] == "attachment; filename=foo.xls" diff --git a/tests/contrib/test_iterio.py b/tests/contrib/test_iterio.py index 00a22d88..e2c48130 100644 --- a/tests/contrib/test_iterio.py +++ b/tests/contrib/test_iterio.py @@ -10,12 +10,12 @@ """ import pytest -from tests import strict_eq -from werkzeug.contrib.iterio import IterIO, greenlet +from .. import strict_eq +from werkzeug.contrib.iterio import greenlet +from werkzeug.contrib.iterio import IterIO class TestIterO(object): - def test_basic_native(self): io = IterIO(["Hello", "World", "1", "2", "3"]) io.seek(0) @@ -34,24 +34,24 @@ class TestIterO(object): assert io.closed io = IterIO(["Hello\n", "World!"]) - assert io.readline() == 'Hello\n' - assert io._buf == 'Hello\n' - assert io.read() == 'World!' - assert io._buf == 'Hello\nWorld!' + assert io.readline() == "Hello\n" + assert io._buf == "Hello\n" + assert io.read() == "World!" + assert io._buf == "Hello\nWorld!" assert io.tell() == 12 io.seek(0) - assert io.readlines() == ['Hello\n', 'World!'] + assert io.readlines() == ["Hello\n", "World!"] - io = IterIO(['Line one\nLine ', 'two\nLine three']) - assert list(io) == ['Line one\n', 'Line two\n', 'Line three'] - io = IterIO(iter('Line one\nLine two\nLine three')) - assert list(io) == ['Line one\n', 'Line two\n', 'Line three'] - io = IterIO(['Line one\nL', 'ine', ' two', '\nLine three']) - assert list(io) == ['Line one\n', 'Line two\n', 'Line three'] + io = IterIO(["Line one\nLine ", "two\nLine three"]) + assert list(io) == ["Line one\n", "Line two\n", "Line three"] + io = IterIO(iter("Line one\nLine two\nLine three")) + assert list(io) == ["Line one\n", "Line two\n", "Line three"] + io = IterIO(["Line one\nL", "ine", " two", "\nLine three"]) + assert list(io) == ["Line one\n", "Line two\n", "Line three"] io = IterIO(["foo\n", "bar"]) io.seek(-4, 2) - assert io.read(4) == '\nbar' + assert io.read(4) == "\nbar" pytest.raises(IOError, io.seek, 2, 100) io.close() @@ -74,17 +74,17 @@ class TestIterO(object): assert io.closed io = IterIO([b"Hello\n", b"World!"]) - assert io.readline() == b'Hello\n' - assert io._buf == b'Hello\n' - assert io.read() == b'World!' - assert io._buf == b'Hello\nWorld!' + assert io.readline() == b"Hello\n" + assert io._buf == b"Hello\n" + assert io.read() == b"World!" + assert io._buf == b"Hello\nWorld!" assert io.tell() == 12 io.seek(0) - assert io.readlines() == [b'Hello\n', b'World!'] + assert io.readlines() == [b"Hello\n", b"World!"] io = IterIO([b"foo\n", b"bar"]) io.seek(-4, 2) - assert io.read(4) == b'\nbar' + assert io.read(4) == b"\nbar" pytest.raises(IOError, io.seek, 2, 100) io.close() @@ -107,17 +107,17 @@ class TestIterO(object): assert io.closed io = IterIO([u"Hello\n", u"World!"]) - assert io.readline() == u'Hello\n' - assert io._buf == u'Hello\n' - assert io.read() == u'World!' - assert io._buf == u'Hello\nWorld!' + assert io.readline() == u"Hello\n" + assert io._buf == u"Hello\n" + assert io.read() == u"World!" + assert io._buf == u"Hello\nWorld!" assert io.tell() == 12 io.seek(0) - assert io.readlines() == [u'Hello\n', u'World!'] + assert io.readlines() == [u"Hello\n", u"World!"] io = IterIO([u"foo\n", u"bar"]) io.seek(-4, 2) - assert io.read(4) == u'\nbar' + assert io.read(4) == u"\nbar" pytest.raises(IOError, io.seek, 2, 100) io.close() @@ -125,60 +125,62 @@ class TestIterO(object): def test_sentinel_cases(self): io = IterIO([]) - strict_eq(io.read(), '') - io = IterIO([], b'') - strict_eq(io.read(), b'') - io = IterIO([], u'') - strict_eq(io.read(), u'') + strict_eq(io.read(), "") + io = IterIO([], b"") + strict_eq(io.read(), b"") + io = IterIO([], u"") + strict_eq(io.read(), u"") io = IterIO([]) - strict_eq(io.read(), '') - io = IterIO([b'']) - strict_eq(io.read(), b'') - io = IterIO([u'']) - strict_eq(io.read(), u'') + strict_eq(io.read(), "") + io = IterIO([b""]) + strict_eq(io.read(), b"") + io = IterIO([u""]) + strict_eq(io.read(), u"") io = IterIO([]) - strict_eq(io.readline(), '') - io = IterIO([], b'') - strict_eq(io.readline(), b'') - io = IterIO([], u'') - strict_eq(io.readline(), u'') + strict_eq(io.readline(), "") + io = IterIO([], b"") + strict_eq(io.readline(), b"") + io = IterIO([], u"") + strict_eq(io.readline(), u"") io = IterIO([]) - strict_eq(io.readline(), '') - io = IterIO([b'']) - strict_eq(io.readline(), b'') - io = IterIO([u'']) - strict_eq(io.readline(), u'') + strict_eq(io.readline(), "") + io = IterIO([b""]) + strict_eq(io.readline(), b"") + io = IterIO([u""]) + strict_eq(io.readline(), u"") -@pytest.mark.skipif(greenlet is None, reason='Greenlet is not installed.') +@pytest.mark.skipif(greenlet is None, reason="Greenlet is not installed.") class TestIterI(object): - def test_basic(self): def producer(out): - out.write('1\n') - out.write('2\n') + out.write("1\n") + out.write("2\n") out.flush() - out.write('3\n') + out.write("3\n") + iterable = IterIO(producer) - assert next(iterable) == '1\n2\n' - assert next(iterable) == '3\n' + assert next(iterable) == "1\n2\n" + assert next(iterable) == "3\n" pytest.raises(StopIteration, next, iterable) def test_sentinel_cases(self): def producer_dummy_flush(out): out.flush() + iterable = IterIO(producer_dummy_flush) - strict_eq(next(iterable), '') + strict_eq(next(iterable), "") def producer_empty(out): pass + iterable = IterIO(producer_empty) pytest.raises(StopIteration, next, iterable) - iterable = IterIO(producer_dummy_flush, b'') - strict_eq(next(iterable), b'') - iterable = IterIO(producer_dummy_flush, u'') - strict_eq(next(iterable), u'') + iterable = IterIO(producer_dummy_flush, b"") + strict_eq(next(iterable), b"") + iterable = IterIO(producer_dummy_flush, u"") + strict_eq(next(iterable), u"") diff --git a/tests/contrib/test_securecookie.py b/tests/contrib/test_securecookie.py index 80e82147..7231ac88 100644 --- a/tests/contrib/test_securecookie.py +++ b/tests/contrib/test_securecookie.py @@ -9,54 +9,59 @@ :license: BSD-3-Clause """ import json + import pytest from werkzeug._compat import to_native -from werkzeug.utils import parse_cookie -from werkzeug.wrappers import Request, Response from werkzeug.contrib.securecookie import SecureCookie +from werkzeug.utils import parse_cookie +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def test_basic_support(): - c = SecureCookie(secret_key=b'foo') + c = SecureCookie(secret_key=b"foo") assert c.new assert not c.modified assert not c.should_save - c['x'] = 42 + c["x"] = 42 assert c.modified assert c.should_save s = c.serialize() - c2 = SecureCookie.unserialize(s, b'foo') + c2 = SecureCookie.unserialize(s, b"foo") assert c is not c2 assert not c2.new assert not c2.modified assert not c2.should_save assert c2 == c - c3 = SecureCookie.unserialize(s, b'wrong foo') + c3 = SecureCookie.unserialize(s, b"wrong foo") assert not c3.modified assert not c3.new assert c3 == {} - c4 = SecureCookie({'x': 42}, 'foo') + c4 = SecureCookie({"x": 42}, "foo") c4_serialized = c4.serialize() - assert SecureCookie.unserialize(c4_serialized, 'foo') == c4 + assert SecureCookie.unserialize(c4_serialized, "foo") == c4 def test_wrapper_support(): req = Request.from_values() resp = Response() - c = SecureCookie.load_cookie(req, secret_key=b'foo') + c = SecureCookie.load_cookie(req, secret_key=b"foo") assert c.new - c['foo'] = 42 - assert c.secret_key == b'foo' + c["foo"] = 42 + assert c.secret_key == b"foo" c.save_cookie(resp) - req = Request.from_values(headers={ - 'Cookie': 'session="%s"' % parse_cookie(resp.headers['set-cookie'])['session'] - }) - c2 = SecureCookie.load_cookie(req, secret_key=b'foo') + req = Request.from_values( + headers={ + "Cookie": 'session="%s"' + % parse_cookie(resp.headers["set-cookie"])["session"] + } + ) + c2 = SecureCookie.load_cookie(req, secret_key=b"foo") assert not c2.new assert c2 == c diff --git a/tests/contrib/test_sessions.py b/tests/contrib/test_sessions.py index a941652f..cab0ae56 100644 --- a/tests/contrib/test_sessions.py +++ b/tests/contrib/test_sessions.py @@ -24,7 +24,7 @@ def test_basic_fs_sessions(tmpdir): x = store.new() assert x.new assert not x.modified - x['foo'] = [1, 2, 3] + x["foo"] = [1, 2, 3] assert x.modified store.save(x) @@ -33,7 +33,7 @@ def test_basic_fs_sessions(tmpdir): assert not x2.modified assert x2 is not x assert x2 == x - x2['test'] = 3 + x2["test"] = 3 assert x2.modified assert not x2.new store.save(x2) @@ -67,7 +67,7 @@ def test_renewing_fs_session(tmpdir): def test_fs_session_lising(tmpdir): store = FilesystemSessionStore(str(tmpdir), renew_missing=True) sessions = set() - for x in range(10): + for _ in range(10): sess = store.new() store.save(sess) sessions.add(sess.sid) diff --git a/tests/contrib/test_wrappers.py b/tests/contrib/test_wrappers.py index 7bc2e6c8..fb49337f 100644 --- a/tests/contrib/test_wrappers.py +++ b/tests/contrib/test_wrappers.py @@ -8,78 +8,81 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - -from werkzeug.contrib import wrappers from werkzeug import routing -from werkzeug.wrappers import Request, Response +from werkzeug.contrib import wrappers +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def test_reverse_slash_behavior(): class MyRequest(wrappers.ReverseSlashBehaviorRequestMixin, Request): pass - req = MyRequest.from_values('/foo/bar', 'http://example.com/test') - assert req.url == 'http://example.com/test/foo/bar' - assert req.path == 'foo/bar' - assert req.script_root == '/test/' + + req = MyRequest.from_values("/foo/bar", "http://example.com/test") + assert req.url == "http://example.com/test/foo/bar" + assert req.path == "foo/bar" + assert req.script_root == "/test/" # make sure the routing system works with the slashes in # reverse order as well. - map = routing.Map([routing.Rule('/foo/bar', endpoint='foo')]) + map = routing.Map([routing.Rule("/foo/bar", endpoint="foo")]) adapter = map.bind_to_environ(req.environ) - assert adapter.match() == ('foo', {}) + assert adapter.match() == ("foo", {}) adapter = map.bind(req.host, req.script_root) - assert adapter.match(req.path) == ('foo', {}) + assert adapter.match(req.path) == ("foo", {}) def test_dynamic_charset_request_mixin(): class MyRequest(wrappers.DynamicCharsetRequestMixin, Request): pass - env = {'CONTENT_TYPE': 'text/html'} + + env = {"CONTENT_TYPE": "text/html"} req = MyRequest(env) - assert req.charset == 'latin1' + assert req.charset == "latin1" - env = {'CONTENT_TYPE': 'text/html; charset=utf-8'} + env = {"CONTENT_TYPE": "text/html; charset=utf-8"} req = MyRequest(env) - assert req.charset == 'utf-8' + assert req.charset == "utf-8" - env = {'CONTENT_TYPE': 'application/octet-stream'} + env = {"CONTENT_TYPE": "application/octet-stream"} req = MyRequest(env) - assert req.charset == 'latin1' - assert req.url_charset == 'latin1' + assert req.charset == "latin1" + assert req.url_charset == "latin1" - MyRequest.url_charset = 'utf-8' - env = {'CONTENT_TYPE': 'application/octet-stream'} + MyRequest.url_charset = "utf-8" + env = {"CONTENT_TYPE": "application/octet-stream"} req = MyRequest(env) - assert req.charset == 'latin1' - assert req.url_charset == 'utf-8' + assert req.charset == "latin1" + assert req.url_charset == "utf-8" def return_ascii(x): return "ascii" - env = {'CONTENT_TYPE': 'text/plain; charset=x-weird-charset'} + + env = {"CONTENT_TYPE": "text/plain; charset=x-weird-charset"} req = MyRequest(env) req.unknown_charset = return_ascii - assert req.charset == 'ascii' - assert req.url_charset == 'utf-8' + assert req.charset == "ascii" + assert req.url_charset == "utf-8" def test_dynamic_charset_response_mixin(): class MyResponse(wrappers.DynamicCharsetResponseMixin, Response): - default_charset = 'utf-7' - resp = MyResponse(mimetype='text/html') - assert resp.charset == 'utf-7' - resp.charset = 'utf-8' - assert resp.charset == 'utf-8' - assert resp.mimetype == 'text/html' - assert resp.mimetype_params == {'charset': 'utf-8'} - resp.mimetype_params['charset'] = 'iso-8859-15' - assert resp.charset == 'iso-8859-15' - resp.set_data(u'Hällo Wörld') - assert b''.join(resp.iter_encoded()) == \ - u'Hällo Wörld'.encode('iso-8859-15') - del resp.headers['content-type'] + default_charset = "utf-7" + + resp = MyResponse(mimetype="text/html") + assert resp.charset == "utf-7" + resp.charset = "utf-8" + assert resp.charset == "utf-8" + assert resp.mimetype == "text/html" + assert resp.mimetype_params == {"charset": "utf-8"} + resp.mimetype_params["charset"] = "iso-8859-15" + assert resp.charset == "iso-8859-15" + resp.set_data(u"Hällo Wörld") + assert b"".join(resp.iter_encoded()) == u"Hällo Wörld".encode("iso-8859-15") + del resp.headers["content-type"] try: - resp.charset = 'utf-8' + resp.charset = "utf-8" except TypeError: pass else: - assert False, 'expected type error on charset setting without ct' + assert False, "expected type error on charset setting without ct" diff --git a/tests/hypothesis/test_urls.py b/tests/hypothesis/test_urls.py index c610714d..61829b3c 100644 --- a/tests/hypothesis/test_urls.py +++ b/tests/hypothesis/test_urls.py @@ -1,5 +1,8 @@ import hypothesis -from hypothesis.strategies import text, dictionaries, lists, integers +from hypothesis.strategies import dictionaries +from hypothesis.strategies import integers +from hypothesis.strategies import lists +from hypothesis.strategies import text from werkzeug import urls from werkzeug.datastructures import OrderedMultiDict @@ -22,8 +25,9 @@ def test_url_encoding_dict_str_list(d): @hypothesis.given(dictionaries(text(), integers())) def test_url_encoding_dict_str_int(d): - assert OrderedMultiDict({k: str(v) for k, v in d.items()}) == \ - urls.url_decode(urls.url_encode(d)) + assert OrderedMultiDict({k: str(v) for k, v in d.items()}) == urls.url_decode( + urls.url_encode(d) + ) @hypothesis.given(text(), text()) diff --git a/tests/multipart/firefox3-2png1txt/request.txt b/tests/multipart/firefox3-2png1txt/request.http Binary files differindex 721e04e3..721e04e3 100644 --- a/tests/multipart/firefox3-2png1txt/request.txt +++ b/tests/multipart/firefox3-2png1txt/request.http diff --git a/tests/multipart/firefox3-2png1txt/text.txt b/tests/multipart/firefox3-2png1txt/text.txt index c87634d5..4491a1e4 100644 --- a/tests/multipart/firefox3-2png1txt/text.txt +++ b/tests/multipart/firefox3-2png1txt/text.txt @@ -1 +1 @@ -example text
\ No newline at end of file +example text diff --git a/tests/multipart/firefox3-2pnglongtext/request.txt b/tests/multipart/firefox3-2pnglongtext/request.http Binary files differindex 489290b6..489290b6 100644 --- a/tests/multipart/firefox3-2pnglongtext/request.txt +++ b/tests/multipart/firefox3-2pnglongtext/request.http diff --git a/tests/multipart/firefox3-2pnglongtext/text.txt b/tests/multipart/firefox3-2pnglongtext/text.txt index 3bf804d5..833ab62e 100644 --- a/tests/multipart/firefox3-2pnglongtext/text.txt +++ b/tests/multipart/firefox3-2pnglongtext/text.txt @@ -1,3 +1,3 @@ --long text --with boundary ---lookalikes--
\ No newline at end of file +--lookalikes-- diff --git a/tests/multipart/ie6-2png1txt/request.txt b/tests/multipart/ie6-2png1txt/request.http Binary files differindex 59fdeae2..59fdeae2 100644 --- a/tests/multipart/ie6-2png1txt/request.txt +++ b/tests/multipart/ie6-2png1txt/request.http diff --git a/tests/multipart/ie6-2png1txt/text.txt b/tests/multipart/ie6-2png1txt/text.txt index 7c465b7a..a32e65e6 100644 --- a/tests/multipart/ie6-2png1txt/text.txt +++ b/tests/multipart/ie6-2png1txt/text.txt @@ -1 +1 @@ -ie6 sucks :-/
\ No newline at end of file +ie6 sucks :-/ diff --git a/tests/multipart/ie7_full_path_request.txt b/tests/multipart/ie7_full_path_request.http Binary files differindex acc4e2e1..acc4e2e1 100644 --- a/tests/multipart/ie7_full_path_request.txt +++ b/tests/multipart/ie7_full_path_request.http diff --git a/tests/multipart/opera8-2png1txt/request.txt b/tests/multipart/opera8-2png1txt/request.http Binary files differindex 8f325914..8f325914 100644 --- a/tests/multipart/opera8-2png1txt/request.txt +++ b/tests/multipart/opera8-2png1txt/request.http diff --git a/tests/multipart/opera8-2png1txt/text.txt b/tests/multipart/opera8-2png1txt/text.txt index ca01cb00..ea10aa51 100644 --- a/tests/multipart/opera8-2png1txt/text.txt +++ b/tests/multipart/opera8-2png1txt/text.txt @@ -1 +1 @@ -blafasel öäü
\ No newline at end of file +blafasel öäü diff --git a/tests/multipart/test_collect.py b/tests/multipart/test_collect.py index 395d8cb0..01b89bbf 100644 --- a/tests/multipart/test_collect.py +++ b/tests/multipart/test_collect.py @@ -3,55 +3,58 @@ Hacky helper application to collect form data. """ from werkzeug.serving import run_simple -from werkzeug.wrappers import Request, Response +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def copy_stream(request): from os import mkdir from time import time - folder = 'request-%d' % time() + + folder = "request-%d" % time() mkdir(folder) environ = request.environ - f = open(folder + '/request.txt', 'wb+') - f.write(environ['wsgi.input'].read(int(environ['CONTENT_LENGTH']))) + f = open(folder + "/request.http", "wb+") + f.write(environ["wsgi.input"].read(int(environ["CONTENT_LENGTH"]))) f.flush() f.seek(0) - environ['wsgi.input'] = f + environ["wsgi.input"] = f request.stat_folder = folder def stats(request): copy_stream(request) - f1 = request.files['file1'] - f2 = request.files['file2'] - text = request.form['text'] - f1.save(request.stat_folder + '/file1.bin') - f2.save(request.stat_folder + '/file2.bin') - with open(request.stat_folder + '/text.txt', 'w') as f: - f.write(text.encode('utf-8')) - return Response('Done.') + f1 = request.files["file1"] + f2 = request.files["file2"] + text = request.form["text"] + f1.save(request.stat_folder + "/file1.bin") + f2.save(request.stat_folder + "/file2.bin") + with open(request.stat_folder + "/text.txt", "w") as f: + f.write(text.encode("utf-8")) + return Response("Done.") def upload_file(request): - return Response(''' - <h1>Upload File</h1> - <form action="" method="post" enctype="multipart/form-data"> - <input type="file" name="file1"><br> - <input type="file" name="file2"><br> - <textarea name="text"></textarea><br> - <input type="submit" value="Send"> - </form> - ''', mimetype='text/html') + return Response( + """<h1>Upload File</h1> + <form action="" method="post" enctype="multipart/form-data"> + <input type="file" name="file1"><br> + <input type="file" name="file2"><br> + <textarea name="text"></textarea><br> + <input type="submit" value="Send"> + </form>""", + mimetype="text/html", + ) def application(environ, start_responseonse): request = Request(environ) - if request.method == 'POST': + if request.method == "POST": response = stats(request) else: response = upload_file(request) return response(environ, start_responseonse) -if __name__ == '__main__': - run_simple('localhost', 5000, application, use_debugger=True) +if __name__ == "__main__": + run_simple("localhost", 5000, application, use_debugger=True) diff --git a/tests/multipart/webkit3-2png1txt/request.txt b/tests/multipart/webkit3-2png1txt/request.http Binary files differindex b4ce0eef..b4ce0eef 100644 --- a/tests/multipart/webkit3-2png1txt/request.txt +++ b/tests/multipart/webkit3-2png1txt/request.http diff --git a/tests/multipart/webkit3-2png1txt/text.txt b/tests/multipart/webkit3-2png1txt/text.txt index baa13008..17537909 100644 --- a/tests/multipart/webkit3-2png1txt/text.txt +++ b/tests/multipart/webkit3-2png1txt/text.txt @@ -1 +1 @@ -this is another text with ümläüts
\ No newline at end of file +this is another text with ümläüts diff --git a/tests/res/chunked.txt b/tests/res/chunked.http index 44aa71f6..44aa71f6 100644 --- a/tests/res/chunked.txt +++ b/tests/res/chunked.http diff --git a/tests/test_compat.py b/tests/test_compat.py index ad8e8a4d..98851ba2 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# flake8: noqa """ tests.compat ~~~~~~~~~~~~ @@ -8,25 +9,32 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - -# This file shouldn't be linted: -# flake8: noqa - -import warnings - -from werkzeug.wrappers import Response from werkzeug.test import create_environ +from werkzeug.wrappers import Response def test_old_imports(): - from werkzeug.utils import Headers, MultiDict, CombinedMultiDict, \ - Headers, EnvironHeaders - from werkzeug.http import Accept, MIMEAccept, CharsetAccept, \ - LanguageAccept, ETags, HeaderSet, WWWAuthenticate, \ - Authorization + from werkzeug.utils import ( + Headers, + MultiDict, + CombinedMultiDict, + Headers, + EnvironHeaders, + ) + from werkzeug.http import ( + Accept, + MIMEAccept, + CharsetAccept, + LanguageAccept, + ETags, + HeaderSet, + WWWAuthenticate, + Authorization, + ) def test_exposed_werkzeug_mod(): import werkzeug + for key in werkzeug.__all__: getattr(werkzeug, key) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 016e0544..20aabf62 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -18,43 +18,46 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - import io -import pytest -import tempfile - -from tests import strict_eq - - import pickle +import tempfile from contextlib import contextmanager -from copy import copy, deepcopy +from copy import copy +from copy import deepcopy + +import pytest -from werkzeug import datastructures, http -from werkzeug._compat import iterkeys, itervalues, iteritems, iterlists, \ - iterlistvalues, text_type, PY2 +from . import strict_eq +from werkzeug import datastructures +from werkzeug import http +from werkzeug._compat import iteritems +from werkzeug._compat import iterkeys +from werkzeug._compat import iterlists +from werkzeug._compat import iterlistvalues +from werkzeug._compat import itervalues +from werkzeug._compat import PY2 +from werkzeug._compat import text_type from werkzeug.datastructures import Range from werkzeug.exceptions import BadRequestKeyError class TestNativeItermethods(object): - def test_basic(self): - @datastructures.native_itermethods(['keys', 'values', 'items']) + @datastructures.native_itermethods(["keys", "values", "items"]) class StupidDict(object): - def keys(self, multi=1): - return iter(['a', 'b', 'c'] * multi) + return iter(["a", "b", "c"] * multi) def values(self, multi=1): return iter([1, 2, 3] * multi) def items(self, multi=1): - return iter(zip(iterkeys(self, multi=multi), - itervalues(self, multi=multi))) + return iter( + zip(iterkeys(self, multi=multi), itervalues(self, multi=multi)) + ) d = StupidDict() - expected_keys = ['a', 'b', 'c'] + expected_keys = ["a", "b", "c"] expected_values = [1, 2, 3] expected_items = list(zip(expected_keys, expected_values)) @@ -81,8 +84,8 @@ class _MutableMultiDictTests(object): cls.__module__ = module d = cls() cls.__module__ = old - d.setlist(b'foo', [1, 2, 3, 4]) - d.setlist(b'bar', b'foo bar baz'.split()) + d.setlist(b"foo", [1, 2, 3, 4]) + d.setlist(b"bar", b"foo bar baz".split()) return d for protocol in range(pickle.HIGHEST_PROTOCOL + 1): @@ -91,157 +94,173 @@ class _MutableMultiDictTests(object): ud = pickle.loads(s) assert type(ud) == type(d) assert ud == d - alternative = pickle.dumps(create_instance('werkzeug'), protocol) + alternative = pickle.dumps(create_instance("werkzeug"), protocol) assert pickle.loads(alternative) == d - ud[b'newkey'] = b'bla' + ud[b"newkey"] = b"bla" assert ud != d def test_basic_interface(self): md = self.storage_class() assert isinstance(md, dict) - mapping = [('a', 1), ('b', 2), ('a', 2), ('d', 3), - ('a', 1), ('a', 3), ('d', 4), ('c', 3)] + mapping = [ + ("a", 1), + ("b", 2), + ("a", 2), + ("d", 3), + ("a", 1), + ("a", 3), + ("d", 4), + ("c", 3), + ] md = self.storage_class(mapping) # simple getitem gives the first value - assert md['a'] == 1 - assert md['c'] == 3 + assert md["a"] == 1 + assert md["c"] == 3 with pytest.raises(KeyError): - md['e'] - assert md.get('a') == 1 + md["e"] + assert md.get("a") == 1 # list getitem - assert md.getlist('a') == [1, 2, 1, 3] - assert md.getlist('d') == [3, 4] + assert md.getlist("a") == [1, 2, 1, 3] + assert md.getlist("d") == [3, 4] # do not raise if key not found - assert md.getlist('x') == [] + assert md.getlist("x") == [] # simple setitem overwrites all values - md['a'] = 42 - assert md.getlist('a') == [42] + md["a"] = 42 + assert md.getlist("a") == [42] # list setitem - md.setlist('a', [1, 2, 3]) - assert md['a'] == 1 - assert md.getlist('a') == [1, 2, 3] + md.setlist("a", [1, 2, 3]) + assert md["a"] == 1 + assert md.getlist("a") == [1, 2, 3] # verify that it does not change original lists l1 = [1, 2, 3] - md.setlist('a', l1) + md.setlist("a", l1) del l1[:] - assert md['a'] == 1 + assert md["a"] == 1 # setdefault, setlistdefault - assert md.setdefault('u', 23) == 23 - assert md.getlist('u') == [23] - del md['u'] + assert md.setdefault("u", 23) == 23 + assert md.getlist("u") == [23] + del md["u"] - md.setlist('u', [-1, -2]) + md.setlist("u", [-1, -2]) # delitem - del md['u'] + del md["u"] with pytest.raises(KeyError): - md['u'] - del md['d'] - assert md.getlist('d') == [] + md["u"] + del md["d"] + assert md.getlist("d") == [] # keys, values, items, lists - assert list(sorted(md.keys())) == ['a', 'b', 'c'] - assert list(sorted(iterkeys(md))) == ['a', 'b', 'c'] + assert list(sorted(md.keys())) == ["a", "b", "c"] + assert list(sorted(iterkeys(md))) == ["a", "b", "c"] assert list(sorted(itervalues(md))) == [1, 2, 3] assert list(sorted(itervalues(md))) == [1, 2, 3] - assert list(sorted(md.items())) == [('a', 1), ('b', 2), ('c', 3)] - assert list(sorted(md.items(multi=True))) == \ - [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)] - assert list(sorted(iteritems(md))) == [('a', 1), ('b', 2), ('c', 3)] - assert list(sorted(iteritems(md, multi=True))) == \ - [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)] + assert list(sorted(md.items())) == [("a", 1), ("b", 2), ("c", 3)] + assert list(sorted(md.items(multi=True))) == [ + ("a", 1), + ("a", 2), + ("a", 3), + ("b", 2), + ("c", 3), + ] + assert list(sorted(iteritems(md))) == [("a", 1), ("b", 2), ("c", 3)] + assert list(sorted(iteritems(md, multi=True))) == [ + ("a", 1), + ("a", 2), + ("a", 3), + ("b", 2), + ("c", 3), + ] - assert list(sorted(md.lists())) == \ - [('a', [1, 2, 3]), ('b', [2]), ('c', [3])] - assert list(sorted(iterlists(md))) == \ - [('a', [1, 2, 3]), ('b', [2]), ('c', [3])] + assert list(sorted(md.lists())) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])] + assert list(sorted(iterlists(md))) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])] # copy method c = md.copy() - assert c['a'] == 1 - assert c.getlist('a') == [1, 2, 3] + assert c["a"] == 1 + assert c.getlist("a") == [1, 2, 3] # copy method 2 c = copy(md) - assert c['a'] == 1 - assert c.getlist('a') == [1, 2, 3] + assert c["a"] == 1 + assert c.getlist("a") == [1, 2, 3] # deepcopy method c = md.deepcopy() - assert c['a'] == 1 - assert c.getlist('a') == [1, 2, 3] + assert c["a"] == 1 + assert c.getlist("a") == [1, 2, 3] # deepcopy method 2 c = deepcopy(md) - assert c['a'] == 1 - assert c.getlist('a') == [1, 2, 3] + assert c["a"] == 1 + assert c.getlist("a") == [1, 2, 3] # update with a multidict - od = self.storage_class([('a', 4), ('a', 5), ('y', 0)]) + od = self.storage_class([("a", 4), ("a", 5), ("y", 0)]) md.update(od) - assert md.getlist('a') == [1, 2, 3, 4, 5] - assert md.getlist('y') == [0] + assert md.getlist("a") == [1, 2, 3, 4, 5] + assert md.getlist("y") == [0] # update with a regular dict md = c - od = {'a': 4, 'y': 0} + od = {"a": 4, "y": 0} md.update(od) - assert md.getlist('a') == [1, 2, 3, 4] - assert md.getlist('y') == [0] + assert md.getlist("a") == [1, 2, 3, 4] + assert md.getlist("y") == [0] # pop, poplist, popitem, popitemlist - assert md.pop('y') == 0 - assert 'y' not in md - assert md.poplist('a') == [1, 2, 3, 4] - assert 'a' not in md - assert md.poplist('missing') == [] + assert md.pop("y") == 0 + assert "y" not in md + assert md.poplist("a") == [1, 2, 3, 4] + assert "a" not in md + assert md.poplist("missing") == [] # remaining: b=2, c=3 popped = md.popitem() - assert popped in [('b', 2), ('c', 3)] + assert popped in [("b", 2), ("c", 3)] popped = md.popitemlist() - assert popped in [('b', [2]), ('c', [3])] + assert popped in [("b", [2]), ("c", [3])] # type conversion - md = self.storage_class({'a': '4', 'b': ['2', '3']}) - assert md.get('a', type=int) == 4 - assert md.getlist('b', type=int) == [2, 3] + md = self.storage_class({"a": "4", "b": ["2", "3"]}) + assert md.get("a", type=int) == 4 + assert md.getlist("b", type=int) == [2, 3] # repr - md = self.storage_class([('a', 1), ('a', 2), ('b', 3)]) + md = self.storage_class([("a", 1), ("a", 2), ("b", 3)]) assert "('a', 1)" in repr(md) assert "('a', 2)" in repr(md) assert "('b', 3)" in repr(md) # add and getlist - md.add('c', '42') - md.add('c', '23') - assert md.getlist('c') == ['42', '23'] - md.add('c', 'blah') - assert md.getlist('c', type=int) == [42, 23] + md.add("c", "42") + md.add("c", "23") + assert md.getlist("c") == ["42", "23"] + md.add("c", "blah") + assert md.getlist("c", type=int) == [42, 23] # setdefault md = self.storage_class() - md.setdefault('x', []).append(42) - md.setdefault('x', []).append(23) - assert md['x'] == [42, 23] + md.setdefault("x", []).append(42) + md.setdefault("x", []).append(23) + assert md["x"] == [42, 23] # to dict md = self.storage_class() - md['foo'] = 42 - md.add('bar', 1) - md.add('bar', 2) - assert md.to_dict() == {'foo': 42, 'bar': 1} - assert md.to_dict(flat=False) == {'foo': [42], 'bar': [1, 2]} + md["foo"] = 42 + md.add("bar", 1) + md.add("bar", 2) + assert md.to_dict() == {"foo": 42, "bar": 1} + assert md.to_dict(flat=False) == {"foo": [42], "bar": [1, 2]} # popitem from empty dict with pytest.raises(KeyError): @@ -256,9 +275,9 @@ class _MutableMultiDictTests(object): # setlist works md = self.storage_class() - md['foo'] = 42 - md.setlist('foo', [1, 2]) - assert md.getlist('foo') == [1, 2] + md["foo"] = 42 + md.setlist("foo", [1, 2]) + assert md.getlist("foo") == [1, 2] class _ImmutableDictTests(object): @@ -267,33 +286,33 @@ class _ImmutableDictTests(object): def test_follows_dict_interface(self): cls = self.storage_class - data = {'foo': 1, 'bar': 2, 'baz': 3} + data = {"foo": 1, "bar": 2, "baz": 3} d = cls(data) - assert d['foo'] == 1 - assert d['bar'] == 2 - assert d['baz'] == 3 - assert sorted(d.keys()) == ['bar', 'baz', 'foo'] - assert 'foo' in d - assert 'foox' not in d + assert d["foo"] == 1 + assert d["bar"] == 2 + assert d["baz"] == 3 + assert sorted(d.keys()) == ["bar", "baz", "foo"] + assert "foo" in d + assert "foox" not in d assert len(d) == 3 def test_copies_are_mutable(self): cls = self.storage_class - immutable = cls({'a': 1}) + immutable = cls({"a": 1}) with pytest.raises(TypeError): - immutable.pop('a') + immutable.pop("a") mutable = immutable.copy() - mutable.pop('a') - assert 'a' in immutable + mutable.pop("a") + assert "a" in immutable assert mutable is not immutable assert copy(immutable) is immutable def test_dict_is_hashable(self): cls = self.storage_class - immutable = cls({'a': 1, 'b': 2}) - immutable2 = cls({'a': 2, 'b': 2}) + immutable = cls({"a": 1, "b": 2}) + immutable2 = cls({"a": 2, "b": 2}) x = set([immutable]) assert immutable in x assert immutable2 not in x @@ -317,8 +336,8 @@ class TestImmutableMultiDict(_ImmutableDictTests): def test_multidict_is_hashable(self): cls = self.storage_class - immutable = cls({'a': [1, 2], 'b': 2}) - immutable2 = cls({'a': [1], 'b': 2}) + immutable = cls({"a": [1, 2], "b": 2}) + immutable2 = cls({"a": [1], "b": 2}) x = set([immutable]) assert immutable in x assert immutable2 not in x @@ -341,8 +360,8 @@ class TestImmutableOrderedMultiDict(_ImmutableDictTests): storage_class = datastructures.ImmutableOrderedMultiDict def test_ordered_multidict_is_hashable(self): - a = self.storage_class([('a', 1), ('b', 1), ('a', 2)]) - b = self.storage_class([('a', 1), ('a', 2), ('b', 1)]) + a = self.storage_class([("a", 1), ("b", 1), ("a", 2)]) + b = self.storage_class([("a", 1), ("a", 2), ("b", 1)]) assert hash(a) != hash(b) @@ -350,93 +369,102 @@ class TestMultiDict(_MutableMultiDictTests): storage_class = datastructures.MultiDict def test_multidict_pop(self): - make_d = lambda: self.storage_class({'foo': [1, 2, 3, 4]}) + def make_d(): + return self.storage_class({"foo": [1, 2, 3, 4]}) + d = make_d() - assert d.pop('foo') == 1 + assert d.pop("foo") == 1 assert not d d = make_d() - assert d.pop('foo', 32) == 1 + assert d.pop("foo", 32) == 1 assert not d d = make_d() - assert d.pop('foos', 32) == 32 + assert d.pop("foos", 32) == 32 assert d with pytest.raises(KeyError): - d.pop('foos') + d.pop("foos") def test_multidict_pop_raise_badrequestkeyerror_for_empty_list_value(self): - mapping = [('a', 'b'), ('a', 'c')] + mapping = [("a", "b"), ("a", "c")] md = self.storage_class(mapping) - md.setlistdefault('empty', []) + md.setlistdefault("empty", []) with pytest.raises(KeyError): - md.pop('empty') + md.pop("empty") def test_multidict_popitem_raise_badrequestkeyerror_for_empty_list_value(self): mapping = [] md = self.storage_class(mapping) - md.setlistdefault('empty', []) + md.setlistdefault("empty", []) with pytest.raises(BadRequestKeyError): md.popitem() def test_setlistdefault(self): md = self.storage_class() - assert md.setlistdefault('u', [-1, -2]) == [-1, -2] - assert md.getlist('u') == [-1, -2] - assert md['u'] == -1 + assert md.setlistdefault("u", [-1, -2]) == [-1, -2] + assert md.getlist("u") == [-1, -2] + assert md["u"] == -1 def test_iter_interfaces(self): - mapping = [('a', 1), ('b', 2), ('a', 2), ('d', 3), - ('a', 1), ('a', 3), ('d', 4), ('c', 3)] + mapping = [ + ("a", 1), + ("b", 2), + ("a", 2), + ("d", 3), + ("a", 1), + ("a", 3), + ("d", 4), + ("c", 3), + ] md = self.storage_class(mapping) assert list(zip(md.keys(), md.listvalues())) == list(md.lists()) assert list(zip(md, iterlistvalues(md))) == list(iterlists(md)) - assert list(zip(iterkeys(md), iterlistvalues(md))) == \ - list(iterlists(md)) + assert list(zip(iterkeys(md), iterlistvalues(md))) == list(iterlists(md)) - @pytest.mark.skipif(not PY2, reason='viewmethods work only for the 2-nd version.') + @pytest.mark.skipif(not PY2, reason="viewmethods work only for the 2-nd version.") def test_view_methods(self): - mapping = [('a', 'b'), ('a', 'c')] + mapping = [("a", "b"), ("a", "c")] md = self.storage_class(mapping) - vi = md.viewitems() - vk = md.viewkeys() - vv = md.viewvalues() + vi = md.viewitems() # noqa: B302 + vk = md.viewkeys() # noqa: B302 + vv = md.viewvalues() # noqa: B302 assert list(vi) == list(md.items()) assert list(vk) == list(md.keys()) assert list(vv) == list(md.values()) - md['k'] = 'n' + md["k"] = "n" assert list(vi) == list(md.items()) assert list(vk) == list(md.keys()) assert list(vv) == list(md.values()) - @pytest.mark.skipif(not PY2, reason='viewmethods work only for the 2-nd version.') + @pytest.mark.skipif(not PY2, reason="viewmethods work only for the 2-nd version.") def test_viewitems_with_multi(self): - mapping = [('a', 'b'), ('a', 'c')] + mapping = [("a", "b"), ("a", "c")] md = self.storage_class(mapping) - vi = md.viewitems(multi=True) + vi = md.viewitems(multi=True) # noqa: B302 assert list(vi) == list(md.items(multi=True)) - md['k'] = 'n' + md["k"] = "n" assert list(vi) == list(md.items(multi=True)) def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self): - mapping = [('a', 'b'), ('a', 'c')] + mapping = [("a", "b"), ("a", "c")] md = self.storage_class(mapping) - md.setlistdefault('empty', []) + md.setlistdefault("empty", []) with pytest.raises(KeyError): - md['empty'] + md["empty"] class TestOrderedMultiDict(_MutableMultiDictTests): @@ -447,90 +475,93 @@ class TestOrderedMultiDict(_MutableMultiDictTests): d = cls() assert not d - d.add('foo', 'bar') + d.add("foo", "bar") assert len(d) == 1 - d.add('foo', 'baz') + d.add("foo", "baz") assert len(d) == 1 - assert list(iteritems(d)) == [('foo', 'bar')] - assert list(d) == ['foo'] - assert list(iteritems(d, multi=True)) == \ - [('foo', 'bar'), ('foo', 'baz')] - del d['foo'] + assert list(iteritems(d)) == [("foo", "bar")] + assert list(d) == ["foo"] + assert list(iteritems(d, multi=True)) == [("foo", "bar"), ("foo", "baz")] + del d["foo"] assert not d assert len(d) == 0 assert list(d) == [] - d.update([('foo', 1), ('foo', 2), ('bar', 42)]) - d.add('foo', 3) - assert d.getlist('foo') == [1, 2, 3] - assert d.getlist('bar') == [42] - assert list(iteritems(d)) == [('foo', 1), ('bar', 42)] + d.update([("foo", 1), ("foo", 2), ("bar", 42)]) + d.add("foo", 3) + assert d.getlist("foo") == [1, 2, 3] + assert d.getlist("bar") == [42] + assert list(iteritems(d)) == [("foo", 1), ("bar", 42)] - expected = ['foo', 'bar'] + expected = ["foo", "bar"] assert list(d.keys()) == expected assert list(d) == expected assert list(iterkeys(d)) == expected - assert list(iteritems(d, multi=True)) == \ - [('foo', 1), ('foo', 2), ('bar', 42), ('foo', 3)] + assert list(iteritems(d, multi=True)) == [ + ("foo", 1), + ("foo", 2), + ("bar", 42), + ("foo", 3), + ] assert len(d) == 2 - assert d.pop('foo') == 1 - assert d.pop('blafasel', None) is None - assert d.pop('blafasel', 42) == 42 + assert d.pop("foo") == 1 + assert d.pop("blafasel", None) is None + assert d.pop("blafasel", 42) == 42 assert len(d) == 1 - assert d.poplist('bar') == [42] + assert d.poplist("bar") == [42] assert not d - d.get('missingkey') is None + d.get("missingkey") is None - d.add('foo', 42) - d.add('foo', 23) - d.add('bar', 2) - d.add('foo', 42) + d.add("foo", 42) + d.add("foo", 23) + d.add("bar", 2) + d.add("foo", 42) assert d == datastructures.MultiDict(d) id = self.storage_class(d) assert d == id - d.add('foo', 2) + d.add("foo", 2) assert d != id - d.update({'blah': [1, 2, 3]}) - assert d['blah'] == 1 - assert d.getlist('blah') == [1, 2, 3] + d.update({"blah": [1, 2, 3]}) + assert d["blah"] == 1 + assert d.getlist("blah") == [1, 2, 3] # setlist works d = self.storage_class() - d['foo'] = 42 - d.setlist('foo', [1, 2]) - assert d.getlist('foo') == [1, 2] + d["foo"] = 42 + d.setlist("foo", [1, 2]) + assert d.getlist("foo") == [1, 2] with pytest.raises(BadRequestKeyError): - d.pop('missing') + d.pop("missing") with pytest.raises(BadRequestKeyError): - d['missing'] + d["missing"] # popping d = self.storage_class() - d.add('foo', 23) - d.add('foo', 42) - d.add('foo', 1) - assert d.popitem() == ('foo', 23) + d.add("foo", 23) + d.add("foo", 42) + d.add("foo", 1) + assert d.popitem() == ("foo", 23) with pytest.raises(BadRequestKeyError): d.popitem() assert not d - d.add('foo', 23) - d.add('foo', 42) - d.add('foo', 1) - assert d.popitemlist() == ('foo', [23, 42, 1]) + d.add("foo", 23) + d.add("foo", 42) + d.add("foo", 1) + assert d.popitemlist() == ("foo", [23, 42, 1]) with pytest.raises(BadRequestKeyError): d.popitemlist() # Unhashable d = self.storage_class() - d.add('foo', 23) + d.add("foo", 23) pytest.raises(TypeError, hash, d) def test_iterables(self): @@ -538,92 +569,91 @@ class TestOrderedMultiDict(_MutableMultiDictTests): b = datastructures.MultiDict((("key_b", "value_b"),)) ab = datastructures.CombinedMultiDict((a, b)) - assert sorted(ab.lists()) == [('key_a', ['value_a']), ('key_b', ['value_b'])] - assert sorted(ab.listvalues()) == [['value_a'], ['value_b']] + assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])] + assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]] assert sorted(ab.keys()) == ["key_a", "key_b"] - assert sorted(iterlists(ab)) == [('key_a', ['value_a']), ('key_b', ['value_b'])] - assert sorted(iterlistvalues(ab)) == [['value_a'], ['value_b']] + assert sorted(iterlists(ab)) == [("key_a", ["value_a"]), ("key_b", ["value_b"])] + assert sorted(iterlistvalues(ab)) == [["value_a"], ["value_b"]] assert sorted(iterkeys(ab)) == ["key_a", "key_b"] def test_get_description(self): data = datastructures.OrderedMultiDict() with pytest.raises(BadRequestKeyError) as exc_info: - data['baz'] + data["baz"] - assert 'baz' in exc_info.value.get_description() + assert "baz" in exc_info.value.get_description() with pytest.raises(BadRequestKeyError) as exc_info: - data.pop('baz') + data.pop("baz") - assert 'baz' in exc_info.value.get_description() + assert "baz" in exc_info.value.get_description() exc_info.value.args = () - assert 'baz' not in exc_info.value.get_description() + assert "baz" not in exc_info.value.get_description() class TestTypeConversionDict(object): storage_class = datastructures.TypeConversionDict def test_value_conversion(self): - d = self.storage_class(foo='1') - assert d.get('foo', type=int) == 1 + d = self.storage_class(foo="1") + assert d.get("foo", type=int) == 1 def test_return_default_when_conversion_is_not_possible(self): - d = self.storage_class(foo='bar') - assert d.get('foo', default=-1, type=int) == -1 + d = self.storage_class(foo="bar") + assert d.get("foo", default=-1, type=int) == -1 def test_propagate_exceptions_in_conversion(self): - d = self.storage_class(foo='bar') - switch = {'a': 1} + d = self.storage_class(foo="bar") + switch = {"a": 1} with pytest.raises(KeyError): - d.get('foo', type=lambda x: switch[x]) + d.get("foo", type=lambda x: switch[x]) class TestCombinedMultiDict(object): storage_class = datastructures.CombinedMultiDict def test_basic_interface(self): - d1 = datastructures.MultiDict([('foo', '1')]) - d2 = datastructures.MultiDict([('bar', '2'), ('bar', '3')]) + d1 = datastructures.MultiDict([("foo", "1")]) + d2 = datastructures.MultiDict([("bar", "2"), ("bar", "3")]) d = self.storage_class([d1, d2]) # lookup - assert d['foo'] == '1' - assert d['bar'] == '2' - assert d.getlist('bar') == ['2', '3'] + assert d["foo"] == "1" + assert d["bar"] == "2" + assert d.getlist("bar") == ["2", "3"] - assert sorted(d.items()) == [('bar', '2'), ('foo', '1')] - assert sorted(d.items(multi=True)) == \ - [('bar', '2'), ('bar', '3'), ('foo', '1')] - assert 'missingkey' not in d - assert 'foo' in d + assert sorted(d.items()) == [("bar", "2"), ("foo", "1")] + assert sorted(d.items(multi=True)) == [("bar", "2"), ("bar", "3"), ("foo", "1")] + assert "missingkey" not in d + assert "foo" in d # type lookup - assert d.get('foo', type=int) == 1 - assert d.getlist('bar', type=int) == [2, 3] + assert d.get("foo", type=int) == 1 + assert d.getlist("bar", type=int) == [2, 3] # get key errors for missing stuff with pytest.raises(KeyError): - d['missing'] + d["missing"] # make sure that they are immutable with pytest.raises(TypeError): - d['foo'] = 'blub' + d["foo"] = "blub" # copies are mutable d = d.copy() - d['foo'] = 'blub' + d["foo"] = "blub" # make sure lists merges md1 = datastructures.MultiDict((("foo", "bar"),)) md2 = datastructures.MultiDict((("foo", "blafasel"),)) x = self.storage_class((md1, md2)) - assert list(iterlists(x)) == [('foo', ['bar', 'blafasel'])] + assert list(iterlists(x)) == [("foo", ["bar", "blafasel"])] def test_length(self): - d1 = datastructures.MultiDict([('foo', '1')]) - d2 = datastructures.MultiDict([('bar', '2')]) + d1 = datastructures.MultiDict([("foo", "1")]) + d2 = datastructures.MultiDict([("bar", "2")]) assert len(d1) == len(d2) == 1 d = self.storage_class([d1, d2]) assert len(d) == 2 @@ -637,143 +667,135 @@ class TestHeaders(object): def test_basic_interface(self): headers = self.storage_class() - headers.add('Content-Type', 'text/plain') - headers.add('X-Foo', 'bar') - assert 'x-Foo' in headers - assert 'Content-type' in headers + headers.add("Content-Type", "text/plain") + headers.add("X-Foo", "bar") + assert "x-Foo" in headers + assert "Content-type" in headers - headers['Content-Type'] = 'foo/bar' - assert headers['Content-Type'] == 'foo/bar' - assert len(headers.getlist('Content-Type')) == 1 + headers["Content-Type"] = "foo/bar" + assert headers["Content-Type"] == "foo/bar" + assert len(headers.getlist("Content-Type")) == 1 # list conversion - assert headers.to_wsgi_list() == [ - ('Content-Type', 'foo/bar'), - ('X-Foo', 'bar') - ] - assert str(headers) == ( - "Content-Type: foo/bar\r\n" - "X-Foo: bar\r\n" - "\r\n" - ) + assert headers.to_wsgi_list() == [("Content-Type", "foo/bar"), ("X-Foo", "bar")] + assert str(headers) == "Content-Type: foo/bar\r\nX-Foo: bar\r\n\r\n" assert str(self.storage_class()) == "\r\n" # extended add - headers.add('Content-Disposition', 'attachment', filename='foo') - assert headers['Content-Disposition'] == 'attachment; filename=foo' + headers.add("Content-Disposition", "attachment", filename="foo") + assert headers["Content-Disposition"] == "attachment; filename=foo" - headers.add('x', 'y', z='"') - assert headers['x'] == r'y; z="\""' + headers.add("x", "y", z='"') + assert headers["x"] == r'y; z="\""' def test_defaults_and_conversion(self): # defaults - headers = self.storage_class([ - ('Content-Type', 'text/plain'), - ('X-Foo', 'bar'), - ('X-Bar', '1'), - ('X-Bar', '2') - ]) - assert headers.getlist('x-bar') == ['1', '2'] - assert headers.get('x-Bar') == '1' - assert headers.get('Content-Type') == 'text/plain' - - assert headers.setdefault('X-Foo', 'nope') == 'bar' - assert headers.setdefault('X-Bar', 'nope') == '1' - assert headers.setdefault('X-Baz', 'quux') == 'quux' - assert headers.setdefault('X-Baz', 'nope') == 'quux' - headers.pop('X-Baz') + headers = self.storage_class( + [ + ("Content-Type", "text/plain"), + ("X-Foo", "bar"), + ("X-Bar", "1"), + ("X-Bar", "2"), + ] + ) + assert headers.getlist("x-bar") == ["1", "2"] + assert headers.get("x-Bar") == "1" + assert headers.get("Content-Type") == "text/plain" + + assert headers.setdefault("X-Foo", "nope") == "bar" + assert headers.setdefault("X-Bar", "nope") == "1" + assert headers.setdefault("X-Baz", "quux") == "quux" + assert headers.setdefault("X-Baz", "nope") == "quux" + headers.pop("X-Baz") # type conversion - assert headers.get('x-bar', type=int) == 1 - assert headers.getlist('x-bar', type=int) == [1, 2] + assert headers.get("x-bar", type=int) == 1 + assert headers.getlist("x-bar", type=int) == [1, 2] # list like operations - assert headers[0] == ('Content-Type', 'text/plain') - assert headers[:1] == self.storage_class([('Content-Type', 'text/plain')]) + assert headers[0] == ("Content-Type", "text/plain") + assert headers[:1] == self.storage_class([("Content-Type", "text/plain")]) del headers[:2] del headers[-1] - assert headers == self.storage_class([('X-Bar', '1')]) + assert headers == self.storage_class([("X-Bar", "1")]) def test_copying(self): - a = self.storage_class([('foo', 'bar')]) + a = self.storage_class([("foo", "bar")]) b = a.copy() - a.add('foo', 'baz') - assert a.getlist('foo') == ['bar', 'baz'] - assert b.getlist('foo') == ['bar'] + a.add("foo", "baz") + assert a.getlist("foo") == ["bar", "baz"] + assert b.getlist("foo") == ["bar"] def test_popping(self): - headers = self.storage_class([('a', 1)]) - assert headers.pop('a') == 1 - assert headers.pop('b', 2) == 2 + headers = self.storage_class([("a", 1)]) + assert headers.pop("a") == 1 + assert headers.pop("b", 2) == 2 with pytest.raises(KeyError): - headers.pop('c') + headers.pop("c") def test_set_arguments(self): a = self.storage_class() - a.set('Content-Disposition', 'useless') - a.set('Content-Disposition', 'attachment', filename='foo') - assert a['Content-Disposition'] == 'attachment; filename=foo' + a.set("Content-Disposition", "useless") + a.set("Content-Disposition", "attachment", filename="foo") + assert a["Content-Disposition"] == "attachment; filename=foo" def test_reject_newlines(self): h = self.storage_class() - for variation in 'foo\nbar', 'foo\r\nbar', 'foo\rbar': + for variation in "foo\nbar", "foo\r\nbar", "foo\rbar": with pytest.raises(ValueError): - h['foo'] = variation + h["foo"] = variation with pytest.raises(ValueError): - h.add('foo', variation) + h.add("foo", variation) with pytest.raises(ValueError): - h.add('foo', 'test', option=variation) + h.add("foo", "test", option=variation) with pytest.raises(ValueError): - h.set('foo', variation) + h.set("foo", variation) with pytest.raises(ValueError): - h.set('foo', 'test', option=variation) + h.set("foo", "test", option=variation) def test_slicing(self): # there's nothing wrong with these being native strings # Headers doesn't care about the data types h = self.storage_class() - h.set('X-Foo-Poo', 'bleh') - h.set('Content-Type', 'application/whocares') - h.set('X-Forwarded-For', '192.168.0.123') - h[:] = [(k, v) for k, v in h if k.startswith(u'X-')] - assert list(h) == [ - ('X-Foo-Poo', 'bleh'), - ('X-Forwarded-For', '192.168.0.123') - ] + h.set("X-Foo-Poo", "bleh") + h.set("Content-Type", "application/whocares") + h.set("X-Forwarded-For", "192.168.0.123") + h[:] = [(k, v) for k, v in h if k.startswith(u"X-")] + assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")] def test_bytes_operations(self): h = self.storage_class() - h.set('X-Foo-Poo', 'bleh') - h.set('X-Whoops', b'\xff') - h.set(b'X-Bytes', b'something') + h.set("X-Foo-Poo", "bleh") + h.set("X-Whoops", b"\xff") + h.set(b"X-Bytes", b"something") - assert h.get('x-foo-poo', as_bytes=True) == b'bleh' - assert h.get('x-whoops', as_bytes=True) == b'\xff' - assert h.get('x-bytes') == 'something' + assert h.get("x-foo-poo", as_bytes=True) == b"bleh" + assert h.get("x-whoops", as_bytes=True) == b"\xff" + assert h.get("x-bytes") == "something" def test_to_wsgi_list(self): h = self.storage_class() - h.set(u'Key', u'Value') + h.set(u"Key", u"Value") for key, value in h.to_wsgi_list(): if PY2: - strict_eq(key, b'Key') - strict_eq(value, b'Value') + strict_eq(key, b"Key") + strict_eq(value, b"Value") else: - strict_eq(key, u'Key') - strict_eq(value, u'Value') + strict_eq(key, u"Key") + strict_eq(value, u"Value") def test_to_wsgi_list_bytes(self): h = self.storage_class() - h.set(b'Key', b'Value') + h.set(b"Key", b"Value") for key, value in h.to_wsgi_list(): if PY2: - strict_eq(key, b'Key') - strict_eq(value, b'Value') + strict_eq(key, b"Key") + strict_eq(value, b"Value") else: - strict_eq(key, u'Key') - strict_eq(value, u'Value') + strict_eq(key, u"Key") + strict_eq(value, u"Value") class TestEnvironHeaders(object): @@ -783,64 +805,53 @@ class TestEnvironHeaders(object): # this happens in multiple WSGI servers because they # use a vary naive way to convert the headers; broken_env = { - 'HTTP_CONTENT_TYPE': 'text/html', - 'CONTENT_TYPE': 'text/html', - 'HTTP_CONTENT_LENGTH': '0', - 'CONTENT_LENGTH': '0', - 'HTTP_ACCEPT': '*', - 'wsgi.version': (1, 0) + "HTTP_CONTENT_TYPE": "text/html", + "CONTENT_TYPE": "text/html", + "HTTP_CONTENT_LENGTH": "0", + "CONTENT_LENGTH": "0", + "HTTP_ACCEPT": "*", + "wsgi.version": (1, 0), } headers = self.storage_class(broken_env) assert headers assert len(headers) == 3 assert sorted(headers) == [ - ('Accept', '*'), - ('Content-Length', '0'), - ('Content-Type', 'text/html') + ("Accept", "*"), + ("Content-Length", "0"), + ("Content-Type", "text/html"), ] - assert not self.storage_class({'wsgi.version': (1, 0)}) - assert len(self.storage_class({'wsgi.version': (1, 0)})) == 0 + assert not self.storage_class({"wsgi.version": (1, 0)}) + assert len(self.storage_class({"wsgi.version": (1, 0)})) == 0 assert 42 not in headers def test_skip_empty_special_vars(self): - env = { - 'HTTP_X_FOO': '42', - 'CONTENT_TYPE': '', - 'CONTENT_LENGTH': '', - } + env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": ""} headers = self.storage_class(env) - assert dict(headers) == {'X-Foo': '42'} + assert dict(headers) == {"X-Foo": "42"} - env = { - 'HTTP_X_FOO': '42', - 'CONTENT_TYPE': '', - 'CONTENT_LENGTH': '0', - } + env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": "0"} headers = self.storage_class(env) - assert dict(headers) == {'X-Foo': '42', 'Content-Length': '0'} + assert dict(headers) == {"X-Foo": "42", "Content-Length": "0"} def test_return_type_is_unicode(self): # environ contains native strings; we return unicode - headers = self.storage_class({ - 'HTTP_FOO': '\xe2\x9c\x93', - 'CONTENT_TYPE': 'text/plain', - }) - assert headers['Foo'] == u"\xe2\x9c\x93" - assert isinstance(headers['Foo'], text_type) - assert isinstance(headers['Content-Type'], text_type) + headers = self.storage_class( + {"HTTP_FOO": "\xe2\x9c\x93", "CONTENT_TYPE": "text/plain"} + ) + assert headers["Foo"] == u"\xe2\x9c\x93" + assert isinstance(headers["Foo"], text_type) + assert isinstance(headers["Content-Type"], text_type) iter_output = dict(iter(headers)) - assert iter_output['Foo'] == u"\xe2\x9c\x93" - assert isinstance(iter_output['Foo'], text_type) - assert isinstance(iter_output['Content-Type'], text_type) + assert iter_output["Foo"] == u"\xe2\x9c\x93" + assert isinstance(iter_output["Foo"], text_type) + assert isinstance(iter_output["Content-Type"], text_type) def test_bytes_operations(self): - foo_val = '\xff' - h = self.storage_class({ - 'HTTP_X_FOO': foo_val - }) + foo_val = "\xff" + h = self.storage_class({"HTTP_X_FOO": foo_val}) - assert h.get('x-foo', as_bytes=True) == b'\xff' - assert h.get('x-foo') == u'\xff' + assert h.get("x-foo", as_bytes=True) == b"\xff" + assert h.get("x-foo") == u"\xff" class TestHeaderSet(object): @@ -848,21 +859,21 @@ class TestHeaderSet(object): def test_basic_interface(self): hs = self.storage_class() - hs.add('foo') - hs.add('bar') - assert 'Bar' in hs - assert hs.find('foo') == 0 - assert hs.find('BAR') == 1 - assert hs.find('baz') < 0 - hs.discard('missing') - hs.discard('foo') - assert hs.find('foo') < 0 - assert hs.find('bar') == 0 + hs.add("foo") + hs.add("bar") + assert "Bar" in hs + assert hs.find("foo") == 0 + assert hs.find("BAR") == 1 + assert hs.find("baz") < 0 + hs.discard("missing") + hs.discard("foo") + assert hs.find("foo") < 0 + assert hs.find("bar") == 0 with pytest.raises(IndexError): - hs.index('missing') + hs.index("missing") - assert hs.index('bar') == 0 + assert hs.index("bar") == 0 assert hs hs.clear() assert not hs @@ -910,47 +921,44 @@ class TestCallbackDict(object): def test_callback_dict_reads(self): assert_calls, func = make_call_asserter() - initial = {'a': 'foo', 'b': 'bar'} + initial = {"a": "foo", "b": "bar"} dct = self.storage_class(initial=initial, on_update=func) - with assert_calls(0, 'callback triggered by read-only method'): + with assert_calls(0, "callback triggered by read-only method"): # read-only methods - dct['a'] - dct.get('a') - pytest.raises(KeyError, lambda: dct['x']) - 'a' in dct + dct["a"] + dct.get("a") + pytest.raises(KeyError, lambda: dct["x"]) + "a" in dct list(iter(dct)) dct.copy() - with assert_calls(0, 'callback triggered without modification'): + with assert_calls(0, "callback triggered without modification"): # methods that may write but don't - dct.pop('z', None) - dct.setdefault('a') + dct.pop("z", None) + dct.setdefault("a") def test_callback_dict_writes(self): assert_calls, func = make_call_asserter() - initial = {'a': 'foo', 'b': 'bar'} + initial = {"a": "foo", "b": "bar"} dct = self.storage_class(initial=initial, on_update=func) - with assert_calls(8, 'callback not triggered by write method'): + with assert_calls(8, "callback not triggered by write method"): # always-write methods - dct['z'] = 123 - dct['z'] = 123 # must trigger again - del dct['z'] - dct.pop('b', None) - dct.setdefault('x') + dct["z"] = 123 + dct["z"] = 123 # must trigger again + del dct["z"] + dct.pop("b", None) + dct.setdefault("x") dct.popitem() dct.update([]) dct.clear() - with assert_calls(0, 'callback triggered by failed del'): - pytest.raises(KeyError, lambda: dct.__delitem__('x')) - with assert_calls(0, 'callback triggered by failed pop'): - pytest.raises(KeyError, lambda: dct.pop('x')) + with assert_calls(0, "callback triggered by failed del"): + pytest.raises(KeyError, lambda: dct.__delitem__("x")) + with assert_calls(0, "callback triggered by failed pop"): + pytest.raises(KeyError, lambda: dct.pop("x")) class TestCacheControl(object): - def test_repr(self): - cc = datastructures.RequestCacheControl( - [("max-age", "0"), ("private", "True")], - ) + cc = datastructures.RequestCacheControl([("max-age", "0"), ("private", "True")]) assert repr(cc) == "<RequestCacheControl max-age='0' private='True'>" @@ -958,117 +966,118 @@ class TestAccept(object): storage_class = datastructures.Accept def test_accept_basic(self): - accept = self.storage_class([('tinker', 0), ('tailor', 0.333), - ('soldier', 0.667), ('sailor', 1)]) + accept = self.storage_class( + [("tinker", 0), ("tailor", 0.333), ("soldier", 0.667), ("sailor", 1)] + ) # check __getitem__ on indices - assert accept[3] == ('tinker', 0) - assert accept[2] == ('tailor', 0.333) - assert accept[1] == ('soldier', 0.667) - assert accept[0], ('sailor', 1) + assert accept[3] == ("tinker", 0) + assert accept[2] == ("tailor", 0.333) + assert accept[1] == ("soldier", 0.667) + assert accept[0], ("sailor", 1) # check __getitem__ on string - assert accept['tinker'] == 0 - assert accept['tailor'] == 0.333 - assert accept['soldier'] == 0.667 - assert accept['sailor'] == 1 - assert accept['spy'] == 0 + assert accept["tinker"] == 0 + assert accept["tailor"] == 0.333 + assert accept["soldier"] == 0.667 + assert accept["sailor"] == 1 + assert accept["spy"] == 0 # check quality method - assert accept.quality('tinker') == 0 - assert accept.quality('tailor') == 0.333 - assert accept.quality('soldier') == 0.667 - assert accept.quality('sailor') == 1 - assert accept.quality('spy') == 0 + assert accept.quality("tinker") == 0 + assert accept.quality("tailor") == 0.333 + assert accept.quality("soldier") == 0.667 + assert accept.quality("sailor") == 1 + assert accept.quality("spy") == 0 # check __contains__ - assert 'sailor' in accept - assert 'spy' not in accept + assert "sailor" in accept + assert "spy" not in accept # check index method - assert accept.index('tinker') == 3 - assert accept.index('tailor') == 2 - assert accept.index('soldier') == 1 - assert accept.index('sailor') == 0 + assert accept.index("tinker") == 3 + assert accept.index("tailor") == 2 + assert accept.index("soldier") == 1 + assert accept.index("sailor") == 0 with pytest.raises(ValueError): - accept.index('spy') + accept.index("spy") # check find method - assert accept.find('tinker') == 3 - assert accept.find('tailor') == 2 - assert accept.find('soldier') == 1 - assert accept.find('sailor') == 0 - assert accept.find('spy') == -1 + assert accept.find("tinker") == 3 + assert accept.find("tailor") == 2 + assert accept.find("soldier") == 1 + assert accept.find("sailor") == 0 + assert accept.find("spy") == -1 # check to_header method - assert accept.to_header() == \ - 'sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0' + assert accept.to_header() == "sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0" # check best_match method - assert accept.best_match(['tinker', 'tailor', 'soldier', 'sailor'], - default=None) == 'sailor' - assert accept.best_match(['tinker', 'tailor', 'soldier'], - default=None) == 'soldier' - assert accept.best_match(['tinker', 'tailor'], default=None) == \ - 'tailor' - assert accept.best_match(['tinker'], default=None) is None - assert accept.best_match(['tinker'], default='x') == 'x' + assert ( + accept.best_match(["tinker", "tailor", "soldier", "sailor"], default=None) + == "sailor" + ) + assert ( + accept.best_match(["tinker", "tailor", "soldier"], default=None) + == "soldier" + ) + assert accept.best_match(["tinker", "tailor"], default=None) == "tailor" + assert accept.best_match(["tinker"], default=None) is None + assert accept.best_match(["tinker"], default="x") == "x" def test_accept_wildcard(self): - accept = self.storage_class([('*', 0), ('asterisk', 1)]) - assert '*' in accept - assert accept.best_match(['asterisk', 'star'], default=None) == \ - 'asterisk' - assert accept.best_match(['star'], default=None) is None + accept = self.storage_class([("*", 0), ("asterisk", 1)]) + assert "*" in accept + assert accept.best_match(["asterisk", "star"], default=None) == "asterisk" + assert accept.best_match(["star"], default=None) is None def test_accept_keep_order(self): - accept = self.storage_class([('*', 1)]) + accept = self.storage_class([("*", 1)]) assert accept.best_match(["alice", "bob"]) == "alice" assert accept.best_match(["bob", "alice"]) == "bob" - accept = self.storage_class([('alice', 1), ('bob', 1)]) + accept = self.storage_class([("alice", 1), ("bob", 1)]) assert accept.best_match(["alice", "bob"]) == "alice" assert accept.best_match(["bob", "alice"]) == "bob" def test_accept_wildcard_specificity(self): - accept = self.storage_class([('asterisk', 0), ('star', 0.5), ('*', 1)]) - assert accept.best_match(['star', 'asterisk'], default=None) == 'star' - assert accept.best_match(['asterisk', 'star'], default=None) == 'star' - assert accept.best_match(['asterisk', 'times'], default=None) == \ - 'times' - assert accept.best_match(['asterisk'], default=None) is None + accept = self.storage_class([("asterisk", 0), ("star", 0.5), ("*", 1)]) + assert accept.best_match(["star", "asterisk"], default=None) == "star" + assert accept.best_match(["asterisk", "star"], default=None) == "star" + assert accept.best_match(["asterisk", "times"], default=None) == "times" + assert accept.best_match(["asterisk"], default=None) is None class TestMIMEAccept(object): storage_class = datastructures.MIMEAccept def test_accept_wildcard_subtype(self): - accept = self.storage_class([('text/*', 1)]) - assert accept.best_match(['text/html'], default=None) == 'text/html' - assert accept.best_match(['image/png', 'text/plain']) == 'text/plain' - assert accept.best_match(['image/png'], default=None) is None + accept = self.storage_class([("text/*", 1)]) + assert accept.best_match(["text/html"], default=None) == "text/html" + assert accept.best_match(["image/png", "text/plain"]) == "text/plain" + assert accept.best_match(["image/png"], default=None) is None def test_accept_wildcard_specificity(self): - accept = self.storage_class([('*/*', 1), ('text/html', 1)]) - assert accept.best_match(['image/png', 'text/html']) == 'text/html' - assert accept.best_match(['image/png', 'text/plain']) == 'image/png' - accept = self.storage_class([('*/*', 1), ('text/html', 1), - ('image/*', 1)]) - assert accept.best_match(['image/png', 'text/html']) == 'text/html' - assert accept.best_match(['text/plain', 'image/png']) == 'image/png' + accept = self.storage_class([("*/*", 1), ("text/html", 1)]) + assert accept.best_match(["image/png", "text/html"]) == "text/html" + assert accept.best_match(["image/png", "text/plain"]) == "image/png" + accept = self.storage_class([("*/*", 1), ("text/html", 1), ("image/*", 1)]) + assert accept.best_match(["image/png", "text/html"]) == "text/html" + assert accept.best_match(["text/plain", "image/png"]) == "image/png" class TestFileStorage(object): storage_class = datastructures.FileStorage def test_mimetype_always_lowercase(self): - file_storage = self.storage_class(content_type='APPLICATION/JSON') - assert file_storage.mimetype == 'application/json' + file_storage = self.storage_class(content_type="APPLICATION/JSON") + assert file_storage.mimetype == "application/json" def test_bytes_proper_sentinel(self): # ensure we iterate over new lines and don't enter into an infinite loop import io + unicode_storage = self.storage_class(io.StringIO(u"one\ntwo")) - for idx, line in enumerate(unicode_storage): + for idx, _line in enumerate(unicode_storage): assert idx < 2 assert idx == 1 binary_storage = self.storage_class(io.BytesIO(b"one\ntwo")) - for idx, line in enumerate(binary_storage): + for idx, _line in enumerate(binary_storage): assert idx < 2 assert idx == 1 - @pytest.mark.skipif(PY2, reason='io.IOBase is only needed in PY3.') + @pytest.mark.skipif(PY2, reason="io.IOBase is only needed in PY3.") @pytest.mark.parametrize("stream", (tempfile.SpooledTemporaryFile, io.BytesIO)) def test_proxy_can_access_stream_attrs(self, stream): """``SpooledTemporaryFile`` doesn't implement some of @@ -1084,9 +1093,7 @@ class TestFileStorage(object): assert hasattr(file_storage, name) -@pytest.mark.parametrize( - "ranges", ([(0, 1), (-5, None)], [(5, None)]) -) +@pytest.mark.parametrize("ranges", ([(0, 1), (-5, None)], [(5, None)])) def test_range_to_header(ranges): header = Range("byes", ranges).to_header() r = http.parse_range_header(header) diff --git a/tests/test_debug.py b/tests/test_debug.py index 43d9ae6d..fa94b653 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -8,43 +8,59 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys -import re import io +import re +import sys import pytest import requests -from werkzeug.debug import get_machine_id, DebuggedApplication -from werkzeug.debug.repr import debug_repr, DebugReprGenerator, \ - dump, helper +from werkzeug._compat import PY2 +from werkzeug.debug import DebuggedApplication +from werkzeug.debug import get_machine_id from werkzeug.debug.console import HTMLStringO +from werkzeug.debug.repr import debug_repr +from werkzeug.debug.repr import DebugReprGenerator +from werkzeug.debug.repr import dump +from werkzeug.debug.repr import helper from werkzeug.debug.tbtools import Traceback from werkzeug.test import Client -from werkzeug.wrappers import Request, Response -from werkzeug._compat import PY2 +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response class TestDebugRepr(object): - def test_basic_repr(self): - assert debug_repr([]) == u'[]' - assert debug_repr([1, 2]) == \ - u'[<span class="number">1</span>, <span class="number">2</span>]' - assert debug_repr([1, 'test']) == \ - u'[<span class="number">1</span>, <span class="string">\'test\'</span>]' - assert debug_repr([None]) == \ - u'[<span class="object">None</span>]' + assert debug_repr([]) == u"[]" + assert ( + debug_repr([1, 2]) + == u'[<span class="number">1</span>, <span class="number">2</span>]' + ) + assert ( + debug_repr([1, "test"]) + == u'[<span class="number">1</span>, <span class="string">\'test\'</span>]' + ) + assert debug_repr([None]) == u'[<span class="object">None</span>]' def test_string_repr(self): - assert debug_repr('') == u'<span class="string">\'\'</span>' - assert debug_repr('foo') == u'<span class="string">\'foo\'</span>' - assert debug_repr('s' * 80) == u'<span class="string">\''\ - + 's' * 69 + '<span class="extended">'\ - + 's' * 11 + '\'</span></span>' - assert debug_repr('<' * 80) == u'<span class="string">\''\ - + '<' * 69 + '<span class="extended">'\ - + '<' * 11 + '\'</span></span>' + assert debug_repr("") == u"<span class=\"string\">''</span>" + assert debug_repr("foo") == u"<span class=\"string\">'foo'</span>" + assert ( + debug_repr("s" * 80) + == u'<span class="string">\'' + + "s" * 69 + + '<span class="extended">' + + "s" * 11 + + "'</span></span>" + ) + assert ( + debug_repr("<" * 80) + == u'<span class="string">\'' + + "<" * 69 + + '<span class="extended">' + + "<" * 11 + + "'</span></span>" + ) def test_string_subclass_repr(self): class Test(str): @@ -52,16 +68,16 @@ class TestDebugRepr(object): assert debug_repr(Test("foo")) == ( u'<span class="module">tests.test_debug.</span>' - u'Test(<span class="string">\'foo\'</span>)' + u"Test(<span class=\"string\">'foo'</span>)" ) @pytest.mark.skipif(not PY2, reason="u prefix on py2 only") def test_unicode_repr(self): - assert debug_repr(u"foo") == u'<span class="string">u\'foo\'</span>' + assert debug_repr(u"foo") == u"<span class=\"string\">u'foo'</span>" @pytest.mark.skipif(PY2, reason="b prefix on py3 only") def test_bytes_repr(self): - assert debug_repr(b"foo") == u'<span class="string">b\'foo\'</span>' + assert debug_repr(b"foo") == u"<span class=\"string\">b'foo'</span>" def test_sequence_repr(self): assert debug_repr(list(range(20))) == ( @@ -79,51 +95,84 @@ class TestDebugRepr(object): ) def test_mapping_repr(self): - assert debug_repr({}) == u'{}' - assert debug_repr({'foo': 42}) == ( + assert debug_repr({}) == u"{}" + assert debug_repr({"foo": 42}) == ( u'{<span class="pair"><span class="key"><span class="string">\'foo\'' u'</span></span>: <span class="value"><span class="number">42' - u'</span></span></span>}' + u"</span></span></span>}" ) assert debug_repr(dict(zip(range(10), [None] * 10))) == ( - u'{<span class="pair"><span class="key"><span class="number">0</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">1</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">2</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">3</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="extended"><span class="pair"><span class="key"><span class="number">4</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">5</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">6</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">7</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">8</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">9</span></span>: <span class="value"><span class="object">None</span></span></span></span>}' # noqa + u'{<span class="pair"><span class="key"><span class="number">0' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">1' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">2' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">3' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="extended">' + u'<span class="pair"><span class="key"><span class="number">4' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">5' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">6' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">7' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">8' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span>, " + u'<span class="pair"><span class="key"><span class="number">9' + u'</span></span>: <span class="value"><span class="object">None' + u"</span></span></span></span>}" ) - assert debug_repr((1, 'zwei', u'drei')) == ( + assert debug_repr((1, "zwei", u"drei")) == ( u'(<span class="number">1</span>, <span class="string">\'' - u'zwei\'</span>, <span class="string">%s\'drei\'</span>)' - ) % ('u' if PY2 else '') + u"zwei'</span>, <span class=\"string\">%s'drei'</span>)" + ) % ("u" if PY2 else "") def test_custom_repr(self): class Foo(object): - def __repr__(self): - return '<Foo 42>' - assert debug_repr(Foo()) == \ - '<span class="object"><Foo 42></span>' + return "<Foo 42>" + + assert debug_repr(Foo()) == '<span class="object"><Foo 42></span>' def test_list_subclass_repr(self): class MyList(list): pass + assert debug_repr(MyList([1, 2])) == ( u'<span class="module">tests.test_debug.</span>MyList([' u'<span class="number">1</span>, <span class="number">2</span>])' ) def test_regex_repr(self): - assert debug_repr(re.compile(r'foo\d')) == \ - u're.compile(<span class="string regex">r\'foo\\d\'</span>)' + assert ( + debug_repr(re.compile(r"foo\d")) + == u"re.compile(<span class=\"string regex\">r'foo\\d'</span>)" + ) # No ur'' in Py3 # https://bugs.python.org/issue15096 - assert debug_repr(re.compile(u'foo\\d')) == ( - u're.compile(<span class="string regex">%sr\'foo\\d\'</span>)' % - ('u' if PY2 else '') + assert debug_repr(re.compile(u"foo\\d")) == ( + u"re.compile(<span class=\"string regex\">%sr'foo\\d'</span>)" + % ("u" if PY2 else "") ) def test_set_repr(self): - assert debug_repr(frozenset('x')) == \ - u'frozenset([<span class="string">\'x\'</span>])' - assert debug_repr(set('x')) == \ - u'set([<span class="string">\'x\'</span>])' + assert ( + debug_repr(frozenset("x")) + == u"frozenset([<span class=\"string\">'x'</span>])" + ) + assert debug_repr(set("x")) == u"set([<span class=\"string\">'x'</span>])" def test_recursive_repr(self): a = [1] @@ -132,13 +181,12 @@ class TestDebugRepr(object): def test_broken_repr(self): class Foo(object): - def __repr__(self): - raise Exception('broken!') + raise Exception("broken!") assert debug_repr(Foo()) == ( u'<span class="brokenrepr"><broken repr (Exception: ' - u'broken!)></span>' + u"broken!)></span>" ) @@ -151,25 +199,24 @@ class Foo(object): class TestDebugHelpers(object): - def test_object_dumping(self): drg = DebugReprGenerator() out = drg.dump_object(Foo()) - assert re.search('Details for tests.test_debug.Foo object at', out) + assert re.search("Details for tests.test_debug.Foo object at", out) assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL) assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL) assert re.search('<th>z.*<span class="number">15</span>', out, flags=re.DOTALL) - out = drg.dump_object({'x': 42, 'y': 23}) - assert re.search('Contents of', out) + out = drg.dump_object({"x": 42, "y": 23}) + assert re.search("Contents of", out) assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL) assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL) - out = drg.dump_object({'x': 42, 'y': 23, 23: 11}) - assert not re.search('Contents of', out) + out = drg.dump_object({"x": 42, "y": 23, 23: 11}) + assert not re.search("Contents of", out) - out = drg.dump_locals({'x': 42, 'y': 23}) - assert re.search('Local variables in frame', out) + out = drg.dump_locals({"x": 42, "y": 23}) + assert re.search("Local variables in frame", out) assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL) assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL) @@ -184,11 +231,11 @@ class TestDebugHelpers(object): finally: sys.stdout = old - assert 'Details for list object at' in x + assert "Details for list object at" in x assert '<span class="number">1</span>' in x - assert 'Local variables in frame' in y - assert '<th>x' in y - assert '<th>old' in y + assert "Local variables in frame" in y + assert "<th>x" in y + assert "<th>old" in y def test_debug_help(self): old = sys.stdout @@ -199,8 +246,8 @@ class TestDebugHelpers(object): finally: sys.stdout = old - assert 'Help on list object' in x - assert '__delitem__' in x + assert "Help on list object" in x + assert "__delitem__" in x @pytest.mark.skipif(PY2, reason="Python 2 doesn't have chained exceptions.") def test_exc_divider_found_on_chained_exception(self): @@ -208,6 +255,7 @@ class TestDebugHelpers(object): def app(request): def do_something(): raise ValueError("inner") + try: do_something() except ValueError: @@ -223,7 +271,6 @@ class TestDebugHelpers(object): class TestTraceback(object): - def test_log(self): try: 1 / 0 @@ -235,12 +282,14 @@ class TestTraceback(object): assert buffer_.getvalue().strip() == traceback.plaintext.strip() def test_sourcelines_encoding(self): - source = (u'# -*- coding: latin1 -*-\n\n' - u'def foo():\n' - u' """höhö"""\n' - u' 1 / 0\n' - u'foo()').encode('latin1') - code = compile(source, filename='lol.py', mode='exec') + source = ( + u"# -*- coding: latin1 -*-\n\n" + u"def foo():\n" + u' """höhö"""\n' + u" 1 / 0\n" + u"foo()" + ).encode("latin1") + code = compile(source, filename="lol.py", mode="exec") try: eval(code) except ZeroDivisionError: @@ -248,23 +297,23 @@ class TestTraceback(object): frames = traceback.frames assert len(frames) == 3 - assert frames[1].filename == 'lol.py' - assert frames[2].filename == 'lol.py' + assert frames[1].filename == "lol.py" + assert frames[2].filename == "lol.py" class Loader(object): - def get_source(self, module): return source frames[1].loader = frames[2].loader = Loader() assert frames[1].sourcelines == frames[2].sourcelines - assert [line.code for line in frames[1].get_annotated_lines()] == \ - [line.code for line in frames[2].get_annotated_lines()] - assert u'höhö' in frames[1].sourcelines[3] + assert [line.code for line in frames[1].get_annotated_lines()] == [ + line.code for line in frames[2].get_annotated_lines() + ] + assert u"höhö" in frames[1].sourcelines[3] def test_filename_encoding(self, tmpdir, monkeypatch): - moduledir = tmpdir.mkdir('föö') - moduledir.join('bar.py').write('def foo():\n 1/0\n') + moduledir = tmpdir.mkdir("föö") + moduledir.join("bar.py").write("def foo():\n 1/0\n") monkeypatch.syspath_prepend(str(moduledir)) import bar @@ -274,7 +323,7 @@ class TestTraceback(object): except ZeroDivisionError: traceback = Traceback(*sys.exc_info()) - assert u'föö' in u'\n'.join(frame.render() for frame in traceback.frames) + assert u"föö" in u"\n".join(frame.render() for frame in traceback.frames) def test_get_machine_id(): @@ -282,9 +331,10 @@ def test_get_machine_id(): assert isinstance(rv, bytes) -@pytest.mark.parametrize('crash', (True, False)) +@pytest.mark.parametrize("crash", (True, False)) def test_basic(dev_server, crash): - server = dev_server(''' + server = dev_server( + """ from werkzeug.debug import DebuggedApplication @DebuggedApplication @@ -293,12 +343,14 @@ def test_basic(dev_server, crash): 1 / 0 start_response('200 OK', [('Content-Type', 'text/html')]) return [b'hello'] - '''.format(crash=crash)) + """.format( + crash=crash + ) + ) r = requests.get(server.url) assert r.status_code == 500 if crash else 200 if crash: - assert 'The debugger caught an exception in your WSGI application' \ - in r.text + assert "The debugger caught an exception in your WSGI application" in r.text else: - assert r.text == 'hello' + assert r.text == "hello" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 7bb19a28..616b39c8 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -15,41 +15,44 @@ import pytest from werkzeug import exceptions +from werkzeug._compat import text_type from werkzeug.datastructures import WWWAuthenticate from werkzeug.wrappers import Response -from werkzeug._compat import text_type def test_proxy_exception(): - orig_resp = Response('Hello World') + orig_resp = Response("Hello World") with pytest.raises(exceptions.HTTPException) as excinfo: exceptions.abort(orig_resp) resp = excinfo.value.get_response({}) assert resp is orig_resp - assert resp.get_data() == b'Hello World' - - -@pytest.mark.parametrize('test', [ - (exceptions.BadRequest, 400), - (exceptions.Unauthorized, 401, 'Basic "test realm"'), - (exceptions.Forbidden, 403), - (exceptions.NotFound, 404), - (exceptions.MethodNotAllowed, 405, ['GET', 'HEAD']), - (exceptions.NotAcceptable, 406), - (exceptions.RequestTimeout, 408), - (exceptions.Gone, 410), - (exceptions.LengthRequired, 411), - (exceptions.PreconditionFailed, 412), - (exceptions.RequestEntityTooLarge, 413), - (exceptions.RequestURITooLarge, 414), - (exceptions.UnsupportedMediaType, 415), - (exceptions.UnprocessableEntity, 422), - (exceptions.Locked, 423), - (exceptions.InternalServerError, 500), - (exceptions.NotImplemented, 501), - (exceptions.BadGateway, 502), - (exceptions.ServiceUnavailable, 503) -]) + assert resp.get_data() == b"Hello World" + + +@pytest.mark.parametrize( + "test", + [ + (exceptions.BadRequest, 400), + (exceptions.Unauthorized, 401, 'Basic "test realm"'), + (exceptions.Forbidden, 403), + (exceptions.NotFound, 404), + (exceptions.MethodNotAllowed, 405, ["GET", "HEAD"]), + (exceptions.NotAcceptable, 406), + (exceptions.RequestTimeout, 408), + (exceptions.Gone, 410), + (exceptions.LengthRequired, 411), + (exceptions.PreconditionFailed, 412), + (exceptions.RequestEntityTooLarge, 413), + (exceptions.RequestURITooLarge, 414), + (exceptions.UnsupportedMediaType, 415), + (exceptions.UnprocessableEntity, 422), + (exceptions.Locked, 423), + (exceptions.InternalServerError, 500), + (exceptions.NotImplemented, 501), + (exceptions.BadGateway, 502), + (exceptions.ServiceUnavailable, 503), + ], +) def test_aborter_general(test): exc_type = test[0] args = test[1:] @@ -72,25 +75,26 @@ def test_aborter_custom(): def test_exception_repr(): exc = exceptions.NotFound() assert text_type(exc) == ( - '404 Not Found: The requested URL was not found ' - 'on the server. If you entered the URL manually please check your ' - 'spelling and try again.') + "404 Not Found: The requested URL was not found on the server." + " If you entered the URL manually please check your spelling" + " and try again." + ) assert repr(exc) == "<NotFound '404: Not Found'>" - exc = exceptions.NotFound('Not There') - assert text_type(exc) == '404 Not Found: Not There' + exc = exceptions.NotFound("Not There") + assert text_type(exc) == "404 Not Found: Not There" assert repr(exc) == "<NotFound '404: Not Found'>" - exc = exceptions.HTTPException('An error message') - assert text_type(exc) == '??? Unknown Error: An error message' + exc = exceptions.HTTPException("An error message") + assert text_type(exc) == "??? Unknown Error: An error message" assert repr(exc) == "<HTTPException '???: Unknown Error'>" def test_method_not_allowed_methods(): - exc = exceptions.MethodNotAllowed(['GET', 'HEAD', 'POST']) + exc = exceptions.MethodNotAllowed(["GET", "HEAD", "POST"]) h = dict(exc.get_headers({})) - assert h['Allow'] == 'GET, HEAD, POST' - assert 'The method is not allowed' in exc.get_description() + assert h["Allow"] == "GET, HEAD, POST" + assert "The method is not allowed" in exc.get_description() def test_unauthorized_www_authenticate(): @@ -101,8 +105,8 @@ def test_unauthorized_www_authenticate(): exc = exceptions.Unauthorized(www_authenticate=basic) h = dict(exc.get_headers({})) - assert h['WWW-Authenticate'] == str(basic) + assert h["WWW-Authenticate"] == str(basic) exc = exceptions.Unauthorized(www_authenticate=[digest, basic]) h = dict(exc.get_headers({})) - assert h['WWW-Authenticate'] == ', '.join((str(digest), str(basic))) + assert h["WWW-Authenticate"] == ", ".join((str(digest), str(basic))) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index 5f80ab5f..6d858386 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -10,106 +10,133 @@ """ import csv import io -import pytest - -from os.path import join, dirname +from os.path import dirname +from os.path import join -from tests import strict_eq +import pytest +from . import strict_eq from werkzeug import formparser -from werkzeug.test import create_environ, Client -from werkzeug.wrappers import Request, Response -from werkzeug.exceptions import RequestEntityTooLarge +from werkzeug._compat import BytesIO +from werkzeug._compat import PY2 from werkzeug.datastructures import MultiDict -from werkzeug.formparser import parse_form_data, FormDataParser -from werkzeug._compat import BytesIO, PY2 +from werkzeug.exceptions import RequestEntityTooLarge +from werkzeug.formparser import FormDataParser +from werkzeug.formparser import parse_form_data +from werkzeug.test import Client +from werkzeug.test import create_environ +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response @Request.application def form_data_consumer(request): - result_object = request.args['object'] - if result_object == 'text': - return Response(repr(request.form['text'])) + result_object = request.args["object"] + if result_object == "text": + return Response(repr(request.form["text"])) f = request.files[result_object] - return Response(b'\n'.join(( - repr(f.filename).encode('ascii'), - repr(f.name).encode('ascii'), - repr(f.content_type).encode('ascii'), - f.stream.read() - ))) + return Response( + b"\n".join( + ( + repr(f.filename).encode("ascii"), + repr(f.name).encode("ascii"), + repr(f.content_type).encode("ascii"), + f.stream.read(), + ) + ) + ) def get_contents(filename): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return f.read() class TestFormParser(object): - def test_limiting(self): - data = b'foo=Hello+World&bar=baz' - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='application/x-www-form-urlencoded', - method='POST') + data = b"foo=Hello+World&bar=baz" + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="application/x-www-form-urlencoded", + method="POST", + ) req.max_content_length = 400 - strict_eq(req.form['foo'], u'Hello World') + strict_eq(req.form["foo"], u"Hello World") - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='application/x-www-form-urlencoded', - method='POST') + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="application/x-www-form-urlencoded", + method="POST", + ) req.max_form_memory_size = 7 - pytest.raises(RequestEntityTooLarge, lambda: req.form['foo']) + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='application/x-www-form-urlencoded', - method='POST') + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="application/x-www-form-urlencoded", + method="POST", + ) req.max_form_memory_size = 400 - strict_eq(req.form['foo'], u'Hello World') - - data = (b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n' - b'Hello World\r\n' - b'--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n' - b'bar=baz\r\n--foo--') - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') + strict_eq(req.form["foo"], u"Hello World") + + data = ( + b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" + b"Hello World\r\n" + b"--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n" + b"bar=baz\r\n--foo--" + ) + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) req.max_content_length = 4 - pytest.raises(RequestEntityTooLarge, lambda: req.form['foo']) + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) req.max_content_length = 400 - strict_eq(req.form['foo'], u'Hello World') + strict_eq(req.form["foo"], u"Hello World") - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) req.max_form_memory_size = 7 - pytest.raises(RequestEntityTooLarge, lambda: req.form['foo']) + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) req.max_form_memory_size = 400 - strict_eq(req.form['foo'], u'Hello World') + strict_eq(req.form["foo"], u"Hello World") def test_missing_multipart_boundary(self): - data = (b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n' - b'Hello World\r\n' - b'--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n' - b'bar=baz\r\n--foo--') - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data', - method='POST') + data = ( + b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" + b"Hello World\r\n" + b"--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n" + b"bar=baz\r\n--foo--" + ) + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data", + method="POST", + ) assert req.form == {} def test_parse_form_data_put_without_content(self): @@ -119,24 +146,23 @@ class TestFormParser(object): # containing an entity-body SHOULD include a Content-Type header field # defining the media type of that body." In the case where either # headers are omitted, parse_form_data should still work. - env = create_environ('/foo', 'http://example.org/', method='PUT') + env = create_environ("/foo", "http://example.org/", method="PUT") stream, form, files = formparser.parse_form_data(env) - strict_eq(stream.read(), b'') + strict_eq(stream.read(), b"") strict_eq(len(form), 0) strict_eq(len(files), 0) def test_parse_form_data_get_without_content(self): - env = create_environ('/foo', 'http://example.org/', method='GET') + env = create_environ("/foo", "http://example.org/", method="GET") stream, form, files = formparser.parse_form_data(env) - strict_eq(stream.read(), b'') + strict_eq(stream.read(), b"") strict_eq(len(form), 0) strict_eq(len(files), 0) @pytest.mark.parametrize( - ("no_spooled", "size"), - ((False, 100), (False, 3000), (True, 100), (True, 3000)), + ("no_spooled", "size"), ((False, 100), (False, 3000), (True, 100), (True, 3000)) ) def test_default_stream_factory(self, no_spooled, size, monkeypatch): if no_spooled: @@ -144,8 +170,7 @@ class TestFormParser(object): data = b"a,b,c\n" * size req = Request.from_values( - data={"foo": (BytesIO(data), "test.txt")}, - method="POST" + data={"foo": (BytesIO(data), "test.txt")}, method="POST" ) file_storage = req.files["foo"] @@ -162,268 +187,338 @@ class TestFormParser(object): file_storage.close() def test_streaming_parse(self): - data = b'x' * (1024 * 600) + data = b"x" * (1024 * 600) class StreamMPP(formparser.MultiPartParser): - def parse(self, file, boundary, content_length): - i = iter(self.parse_lines(file, boundary, content_length, - cap_at_buffer=False)) + i = iter( + self.parse_lines( + file, boundary, content_length, cap_at_buffer=False + ) + ) one = next(i) two = next(i) - return self.cls(()), {'one': one, 'two': two} + return self.cls(()), {"one": one, "two": two} class StreamFDP(formparser.FormDataParser): - - def _sf_parse_multipart(self, stream, mimetype, - content_length, options): + def _sf_parse_multipart(self, stream, mimetype, content_length, options): form, files = StreamMPP( - self.stream_factory, self.charset, self.errors, + self.stream_factory, + self.charset, + self.errors, max_form_memory_size=self.max_form_memory_size, - cls=self.cls).parse(stream, options.get('boundary').encode('ascii'), - content_length) + cls=self.cls, + ).parse(stream, options.get("boundary").encode("ascii"), content_length) return stream, form, files + parse_functions = {} parse_functions.update(formparser.FormDataParser.parse_functions) - parse_functions['multipart/form-data'] = _sf_parse_multipart + parse_functions["multipart/form-data"] = _sf_parse_multipart class StreamReq(Request): form_data_parser_class = StreamFDP - req = StreamReq.from_values(data={'foo': (BytesIO(data), 'test.txt')}, - method='POST') - strict_eq('begin_file', req.files['one'][0]) - strict_eq(('foo', 'test.txt'), req.files['one'][1][1:]) - strict_eq('cont', req.files['two'][0]) - strict_eq(data, req.files['two'][1]) + + req = StreamReq.from_values( + data={"foo": (BytesIO(data), "test.txt")}, method="POST" + ) + strict_eq("begin_file", req.files["one"][0]) + strict_eq(("foo", "test.txt"), req.files["one"][1][1:]) + strict_eq("cont", req.files["two"][0]) + strict_eq(data, req.files["two"][1]) def test_parse_bad_content_type(self): parser = FormDataParser() - assert parser.parse('', 'bad-mime-type', 0) == \ - ('', MultiDict([]), MultiDict([])) + assert parser.parse("", "bad-mime-type", 0) == ( + "", + MultiDict([]), + MultiDict([]), + ) def test_parse_from_environ(self): parser = FormDataParser() - stream, _, _ = parser.parse_from_environ({'wsgi.input': ''}) + stream, _, _ = parser.parse_from_environ({"wsgi.input": ""}) assert stream is not None class TestMultiPart(object): - def test_basic(self): - resources = join(dirname(__file__), 'multipart') + resources = join(dirname(__file__), "multipart") client = Client(form_data_consumer, Response) repository = [ - ('firefox3-2png1txt', '---------------------------186454651713519341951581030105', [ - (u'anchor.png', 'file1', 'image/png', 'file1.png'), - (u'application_edit.png', 'file2', 'image/png', 'file2.png') - ], u'example text'), - ('firefox3-2pnglongtext', '---------------------------14904044739787191031754711748', [ - (u'accept.png', 'file1', 'image/png', 'file1.png'), - (u'add.png', 'file2', 'image/png', 'file2.png') - ], u'--long text\r\n--with boundary\r\n--lookalikes--'), - ('opera8-2png1txt', '----------zEO9jQKmLc2Cq88c23Dx19', [ - (u'arrow_branch.png', 'file1', 'image/png', 'file1.png'), - (u'award_star_bronze_1.png', 'file2', 'image/png', 'file2.png') - ], u'blafasel öäü'), - ('webkit3-2png1txt', '----WebKitFormBoundaryjdSFhcARk8fyGNy6', [ - (u'gtk-apply.png', 'file1', 'image/png', 'file1.png'), - (u'gtk-no.png', 'file2', 'image/png', 'file2.png') - ], u'this is another text with ümläüts'), - ('ie6-2png1txt', '---------------------------7d91b03a20128', [ - (u'file1.png', 'file1', 'image/x-png', 'file1.png'), - (u'file2.png', 'file2', 'image/x-png', 'file2.png') - ], u'ie6 sucks :-/') + ( + "firefox3-2png1txt", + "---------------------------186454651713519341951581030105", + [ + (u"anchor.png", "file1", "image/png", "file1.png"), + (u"application_edit.png", "file2", "image/png", "file2.png"), + ], + u"example text", + ), + ( + "firefox3-2pnglongtext", + "---------------------------14904044739787191031754711748", + [ + (u"accept.png", "file1", "image/png", "file1.png"), + (u"add.png", "file2", "image/png", "file2.png"), + ], + u"--long text\r\n--with boundary\r\n--lookalikes--", + ), + ( + "opera8-2png1txt", + "----------zEO9jQKmLc2Cq88c23Dx19", + [ + (u"arrow_branch.png", "file1", "image/png", "file1.png"), + (u"award_star_bronze_1.png", "file2", "image/png", "file2.png"), + ], + u"blafasel öäü", + ), + ( + "webkit3-2png1txt", + "----WebKitFormBoundaryjdSFhcARk8fyGNy6", + [ + (u"gtk-apply.png", "file1", "image/png", "file1.png"), + (u"gtk-no.png", "file2", "image/png", "file2.png"), + ], + u"this is another text with ümläüts", + ), + ( + "ie6-2png1txt", + "---------------------------7d91b03a20128", + [ + (u"file1.png", "file1", "image/x-png", "file1.png"), + (u"file2.png", "file2", "image/x-png", "file2.png"), + ], + u"ie6 sucks :-/", + ), ] for name, boundary, files, text in repository: folder = join(resources, name) - data = get_contents(join(folder, 'request.txt')) + data = get_contents(join(folder, "request.http")) for filename, field, content_type, fsname in files: response = client.post( - '/?object=' + field, + "/?object=" + field, data=data, content_type='multipart/form-data; boundary="%s"' % boundary, - content_length=len(data)) - lines = response.get_data().split(b'\n', 3) - strict_eq(lines[0], repr(filename).encode('ascii')) - strict_eq(lines[1], repr(field).encode('ascii')) - strict_eq(lines[2], repr(content_type).encode('ascii')) + content_length=len(data), + ) + lines = response.get_data().split(b"\n", 3) + strict_eq(lines[0], repr(filename).encode("ascii")) + strict_eq(lines[1], repr(field).encode("ascii")) + strict_eq(lines[2], repr(content_type).encode("ascii")) strict_eq(lines[3], get_contents(join(folder, fsname))) response = client.post( - '/?object=text', + "/?object=text", data=data, content_type='multipart/form-data; boundary="%s"' % boundary, - content_length=len(data)) - strict_eq(response.get_data(), repr(text).encode('utf-8')) + content_length=len(data), + ) + strict_eq(response.get_data(), repr(text).encode("utf-8")) def test_ie7_unc_path(self): client = Client(form_data_consumer, Response) - data_file = join(dirname(__file__), 'multipart', 'ie7_full_path_request.txt') + data_file = join(dirname(__file__), "multipart", "ie7_full_path_request.http") data = get_contents(data_file) - boundary = '---------------------------7da36d1b4a0164' + boundary = "---------------------------7da36d1b4a0164" response = client.post( - '/?object=cb_file_upload_multiple', + "/?object=cb_file_upload_multiple", data=data, content_type='multipart/form-data; boundary="%s"' % boundary, - content_length=len(data)) - lines = response.get_data().split(b'\n', 3) - strict_eq(lines[0], - repr(u'Sellersburg Town Council Meeting 02-22-2010doc.doc').encode('ascii')) + content_length=len(data), + ) + lines = response.get_data().split(b"\n", 3) + strict_eq( + lines[0], + repr(u"Sellersburg Town Council Meeting 02-22-2010doc.doc").encode("ascii"), + ) def test_end_of_file(self): # This test looks innocent but it was actually timeing out in # the Werkzeug 0.5 release version (#394) data = ( - b'--foo\r\n' + b"--foo\r\n" b'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n' - b'Content-Type: text/plain\r\n\r\n' - b'file contents and no end' + b"Content-Type: text/plain\r\n\r\n" + b"file contents and no end" + ) + data = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", ) - data = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') assert not data.files assert not data.form def test_broken(self): data = ( - '--foo\r\n' + "--foo\r\n" 'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n' - 'Content-Transfer-Encoding: base64\r\n' - 'Content-Type: text/plain\r\n\r\n' - 'broken base 64' - '--foo--' - ) - _, form, files = formparser.parse_form_data(create_environ( - data=data, method='POST', content_type='multipart/form-data; boundary=foo' - )) + "Content-Transfer-Encoding: base64\r\n" + "Content-Type: text/plain\r\n\r\n" + "broken base 64" + "--foo--" + ) + _, form, files = formparser.parse_form_data( + create_environ( + data=data, + method="POST", + content_type="multipart/form-data; boundary=foo", + ) + ) assert not files assert not form - pytest.raises(ValueError, formparser.parse_form_data, - create_environ(data=data, method='POST', - content_type='multipart/form-data; boundary=foo'), - silent=False) + pytest.raises( + ValueError, + formparser.parse_form_data, + create_environ( + data=data, + method="POST", + content_type="multipart/form-data; boundary=foo", + ), + silent=False, + ) def test_file_no_content_type(self): data = ( - b'--foo\r\n' + b"--foo\r\n" b'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n\r\n' - b'file contents\r\n--foo--' + b"file contents\r\n--foo--" ) - data = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') - assert data.files['test'].filename == 'test.txt' - strict_eq(data.files['test'].read(), b'file contents') + data = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) + assert data.files["test"].filename == "test.txt" + strict_eq(data.files["test"].read(), b"file contents") def test_extra_newline(self): # this test looks innocent but it was actually timeing out in # the Werkzeug 0.5 release version (#394) data = ( - b'\r\n\r\n--foo\r\n' + b"\r\n\r\n--foo\r\n" b'Content-Disposition: form-data; name="foo"\r\n\r\n' - b'a string\r\n' - b'--foo--' + b"a string\r\n" + b"--foo--" + ) + data = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", ) - data = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') assert not data.files - strict_eq(data.form['foo'], u'a string') + strict_eq(data.form["foo"], u"a string") def test_headers(self): - data = (b'--foo\r\n' - b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' - b'X-Custom-Header: blah\r\n' - b'Content-Type: text/plain; charset=utf-8\r\n\r\n' - b'file contents, just the contents\r\n' - b'--foo--') - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') - foo = req.files['foo'] - strict_eq(foo.mimetype, 'text/plain') - strict_eq(foo.mimetype_params, {'charset': 'utf-8'}) - strict_eq(foo.headers['content-type'], foo.content_type) - strict_eq(foo.content_type, 'text/plain; charset=utf-8') - strict_eq(foo.headers['x-custom-header'], 'blah') + data = ( + b"--foo\r\n" + b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' + b"X-Custom-Header: blah\r\n" + b"Content-Type: text/plain; charset=utf-8\r\n\r\n" + b"file contents, just the contents\r\n" + b"--foo--" + ) + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) + foo = req.files["foo"] + strict_eq(foo.mimetype, "text/plain") + strict_eq(foo.mimetype_params, {"charset": "utf-8"}) + strict_eq(foo.headers["content-type"], foo.content_type) + strict_eq(foo.content_type, "text/plain; charset=utf-8") + strict_eq(foo.headers["x-custom-header"], "blah") def test_nonstandard_line_endings(self): - for nl in b'\n', b'\r', b'\r\n': - data = nl.join(( - b'--foo', - b'Content-Disposition: form-data; name=foo', - b'', - b'this is just bar', - b'--foo', - b'Content-Disposition: form-data; name=bar', - b'', - b'blafasel', - b'--foo--' - )) - req = Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; ' - 'boundary=foo', method='POST') - strict_eq(req.form['foo'], u'this is just bar') - strict_eq(req.form['bar'], u'blafasel') + for nl in b"\n", b"\r", b"\r\n": + data = nl.join( + ( + b"--foo", + b"Content-Disposition: form-data; name=foo", + b"", + b"this is just bar", + b"--foo", + b"Content-Disposition: form-data; name=bar", + b"", + b"blafasel", + b"--foo--", + ) + ) + req = Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) + strict_eq(req.form["foo"], u"this is just bar") + strict_eq(req.form["bar"], u"blafasel") def test_failures(self): def parse_multipart(stream, boundary, content_length): parser = formparser.MultiPartParser(content_length) return parser.parse(stream, boundary, content_length) - pytest.raises(ValueError, parse_multipart, BytesIO(), b'broken ', 0) - data = b'--foo\r\n\r\nHello World\r\n--foo--' - pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data)) + pytest.raises(ValueError, parse_multipart, BytesIO(), b"broken ", 0) - data = b'--foo\r\nContent-Disposition: form-field; name=foo\r\n' \ - b'Content-Transfer-Encoding: base64\r\n\r\nHello World\r\n--foo--' - pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data)) + data = b"--foo\r\n\r\nHello World\r\n--foo--" + pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data)) - data = b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\nHello World\r\n' - pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data)) + data = ( + b"--foo\r\nContent-Disposition: form-field; name=foo\r\n" + b"Content-Transfer-Encoding: base64\r\n\r\nHello World\r\n--foo--" + ) + pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data)) - x = formparser.parse_multipart_headers(['foo: bar\r\n', ' x test\r\n']) - strict_eq(x['foo'], 'bar\n x test') - pytest.raises(ValueError, formparser.parse_multipart_headers, - ['foo: bar\r\n', ' x test']) + data = ( + b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\nHello World\r\n" + ) + pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data)) + + x = formparser.parse_multipart_headers(["foo: bar\r\n", " x test\r\n"]) + strict_eq(x["foo"], "bar\n x test") + pytest.raises( + ValueError, formparser.parse_multipart_headers, ["foo: bar\r\n", " x test"] + ) def test_bad_newline_bad_newline_assumption(self): class ISORequest(Request): - charset = 'latin1' - contents = b'U2vlbmUgbORu' - data = b'--foo\r\nContent-Disposition: form-data; name="test"\r\n' \ - b'Content-Transfer-Encoding: base64\r\n\r\n' + \ - contents + b'\r\n--foo--' - req = ISORequest.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST') - strict_eq(req.form['test'], u'Sk\xe5ne l\xe4n') + charset = "latin1" + + contents = b"U2vlbmUgbORu" + data = ( + b'--foo\r\nContent-Disposition: form-data; name="test"\r\n' + b"Content-Transfer-Encoding: base64\r\n\r\n" + contents + b"\r\n--foo--" + ) + req = ISORequest.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) + strict_eq(req.form["test"], u"Sk\xe5ne l\xe4n") def test_empty_multipart(self): environ = {} - data = b'--boundary--' - environ['REQUEST_METHOD'] = 'POST' - environ['CONTENT_TYPE'] = 'multipart/form-data; boundary=boundary' - environ['CONTENT_LENGTH'] = str(len(data)) - environ['wsgi.input'] = BytesIO(data) + data = b"--boundary--" + environ["REQUEST_METHOD"] = "POST" + environ["CONTENT_TYPE"] = "multipart/form-data; boundary=boundary" + environ["CONTENT_LENGTH"] = str(len(data)) + environ["wsgi.input"] = BytesIO(data) stream, form, files = parse_form_data(environ, silent=False) rv = stream.read() - assert rv == b'' + assert rv == b"" assert form == MultiDict() assert files == MultiDict() class TestMultiPartParser(object): - def test_constructor_not_pass_stream_factory_and_cls(self): parser = formparser.MultiPartParser() @@ -443,7 +538,7 @@ class TestMultiPartParser(object): data = ( b"--foo\r\n" b"Content-Type: text/plain; charset=utf-8\r\n" - b'Content-Disposition: form-data; name=rfc2231;\r\n' + b"Content-Disposition: form-data; name=rfc2231;\r\n" b" filename*0*=ascii''a%20b%20;\r\n" b" filename*1*=c%20d%20;\r\n" b' filename*2="e f.txt"\r\n\r\n' @@ -453,25 +548,24 @@ class TestMultiPartParser(object): input_stream=BytesIO(data), content_length=len(data), content_type="multipart/form-data; boundary=foo", - method="POST" + method="POST", ) assert request.files["rfc2231"].filename == "a b c d e f.txt" assert request.files["rfc2231"].read() == b"file contents" class TestInternalFunctions(object): - def test_line_parser(self): - assert formparser._line_parse('foo') == ('foo', False) - assert formparser._line_parse('foo\r\n') == ('foo', True) - assert formparser._line_parse('foo\r') == ('foo', True) - assert formparser._line_parse('foo\n') == ('foo', True) + assert formparser._line_parse("foo") == ("foo", False) + assert formparser._line_parse("foo\r\n") == ("foo", True) + assert formparser._line_parse("foo\r") == ("foo", True) + assert formparser._line_parse("foo\n") == ("foo", True) def test_find_terminator(self): - lineiter = iter(b'\n\n\nfoo\nbar\nbaz'.splitlines(True)) + lineiter = iter(b"\n\n\nfoo\nbar\nbaz".splitlines(True)) find_terminator = formparser.MultiPartParser()._find_terminator line = find_terminator(lineiter) - assert line == b'foo' - assert list(lineiter) == [b'bar\n', b'baz'] - assert find_terminator([]) == b'' - assert find_terminator([b'']) == b'' + assert line == b"foo" + assert list(lineiter) == [b"bar\n", b"baz"] + assert find_terminator([]) == b"" + assert find_terminator([b""]) == b"" diff --git a/tests/test_http.py b/tests/test_http.py index b261a605..c6e8309a 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -8,199 +8,222 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pytest - from datetime import datetime -from tests import strict_eq -from werkzeug._compat import itervalues, wsgi_encoding_dance +import pytest -from werkzeug import http, datastructures +from . import strict_eq +from werkzeug import datastructures +from werkzeug import http +from werkzeug._compat import itervalues +from werkzeug._compat import wsgi_encoding_dance from werkzeug.test import create_environ class TestHTTPUtility(object): - def test_accept(self): - a = http.parse_accept_header('en-us,ru;q=0.5') - assert list(itervalues(a)) == ['en-us', 'ru'] - assert a.best == 'en-us' - assert a.find('ru') == 1 - pytest.raises(ValueError, a.index, 'de') - assert a.to_header() == 'en-us,ru;q=0.5' + a = http.parse_accept_header("en-us,ru;q=0.5") + assert list(itervalues(a)) == ["en-us", "ru"] + assert a.best == "en-us" + assert a.find("ru") == 1 + pytest.raises(ValueError, a.index, "de") + assert a.to_header() == "en-us,ru;q=0.5" def test_mime_accept(self): - a = http.parse_accept_header('text/xml,application/xml,' - 'application/xhtml+xml,' - 'application/foo;quiet=no; bar=baz;q=0.6,' - 'text/html;q=0.9,text/plain;q=0.8,' - 'image/png,*/*;q=0.5', - datastructures.MIMEAccept) - pytest.raises(ValueError, lambda: a['missing']) - assert a['image/png'] == 1 - assert a['text/plain'] == 0.8 - assert a['foo/bar'] == 0.5 - assert a['application/foo;quiet=no; bar=baz'] == 0.6 - assert a[a.find('foo/bar')] == ('*/*', 0.5) + a = http.parse_accept_header( + "text/xml,application/xml," + "application/xhtml+xml," + "application/foo;quiet=no; bar=baz;q=0.6," + "text/html;q=0.9,text/plain;q=0.8," + "image/png,*/*;q=0.5", + datastructures.MIMEAccept, + ) + pytest.raises(ValueError, lambda: a["missing"]) + assert a["image/png"] == 1 + assert a["text/plain"] == 0.8 + assert a["foo/bar"] == 0.5 + assert a["application/foo;quiet=no; bar=baz"] == 0.6 + assert a[a.find("foo/bar")] == ("*/*", 0.5) def test_accept_matches(self): - a = http.parse_accept_header('text/xml,application/xml,application/xhtml+xml,' - 'text/html;q=0.9,text/plain;q=0.8,' - 'image/png', datastructures.MIMEAccept) - assert a.best_match(['text/html', 'application/xhtml+xml']) == \ - 'application/xhtml+xml' - assert a.best_match(['text/html']) == 'text/html' - assert a.best_match(['foo/bar']) is None - assert a.best_match(['foo/bar', 'bar/foo'], default='foo/bar') == 'foo/bar' - assert a.best_match(['application/xml', 'text/xml']) == 'application/xml' + a = http.parse_accept_header( + "text/xml,application/xml,application/xhtml+xml," + "text/html;q=0.9,text/plain;q=0.8," + "image/png", + datastructures.MIMEAccept, + ) + assert ( + a.best_match(["text/html", "application/xhtml+xml"]) + == "application/xhtml+xml" + ) + assert a.best_match(["text/html"]) == "text/html" + assert a.best_match(["foo/bar"]) is None + assert a.best_match(["foo/bar", "bar/foo"], default="foo/bar") == "foo/bar" + assert a.best_match(["application/xml", "text/xml"]) == "application/xml" def test_charset_accept(self): - a = http.parse_accept_header('ISO-8859-1,utf-8;q=0.7,*;q=0.7', - datastructures.CharsetAccept) - assert a['iso-8859-1'] == a['iso8859-1'] - assert a['iso-8859-1'] == 1 - assert a['UTF8'] == 0.7 - assert a['ebcdic'] == 0.7 + a = http.parse_accept_header( + "ISO-8859-1,utf-8;q=0.7,*;q=0.7", datastructures.CharsetAccept + ) + assert a["iso-8859-1"] == a["iso8859-1"] + assert a["iso-8859-1"] == 1 + assert a["UTF8"] == 0.7 + assert a["ebcdic"] == 0.7 def test_language_accept(self): - a = http.parse_accept_header('de-AT,de;q=0.8,en;q=0.5', - datastructures.LanguageAccept) - assert a.best == 'de-AT' - assert 'de_AT' in a - assert 'en' in a - assert a['de-at'] == 1 - assert a['en'] == 0.5 + a = http.parse_accept_header( + "de-AT,de;q=0.8,en;q=0.5", datastructures.LanguageAccept + ) + assert a.best == "de-AT" + assert "de_AT" in a + assert "en" in a + assert a["de-at"] == 1 + assert a["en"] == 0.5 def test_set_header(self): hs = http.parse_set_header('foo, Bar, "Blah baz", Hehe') - assert 'blah baz' in hs - assert 'foobar' not in hs - assert 'foo' in hs - assert list(hs) == ['foo', 'Bar', 'Blah baz', 'Hehe'] - hs.add('Foo') + assert "blah baz" in hs + assert "foobar" not in hs + assert "foo" in hs + assert list(hs) == ["foo", "Bar", "Blah baz", "Hehe"] + hs.add("Foo") assert hs.to_header() == 'foo, Bar, "Blah baz", Hehe' def test_list_header(self): - hl = http.parse_list_header('foo baz, blah') - assert hl == ['foo baz', 'blah'] + hl = http.parse_list_header("foo baz, blah") + assert hl == ["foo baz", "blah"] def test_dict_header(self): d = http.parse_dict_header('foo="bar baz", blah=42') - assert d == {'foo': 'bar baz', 'blah': '42'} + assert d == {"foo": "bar baz", "blah": "42"} def test_cache_control_header(self): - cc = http.parse_cache_control_header('max-age=0, no-cache') + cc = http.parse_cache_control_header("max-age=0, no-cache") assert cc.max_age == 0 assert cc.no_cache - cc = http.parse_cache_control_header('private, community="UCI"', None, - datastructures.ResponseCacheControl) + cc = http.parse_cache_control_header( + 'private, community="UCI"', None, datastructures.ResponseCacheControl + ) assert cc.private - assert cc['community'] == 'UCI' + assert cc["community"] == "UCI" c = datastructures.ResponseCacheControl() assert c.no_cache is None assert c.private is None c.no_cache = True - assert c.no_cache == '*' + assert c.no_cache == "*" c.private = True - assert c.private == '*' + assert c.private == "*" del c.private assert c.private is None - assert c.to_header() == 'no-cache' + assert c.to_header() == "no-cache" def test_authorization_header(self): - a = http.parse_authorization_header('Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==') - assert a.type == 'basic' - assert a.username == u'Aladdin' - assert a.password == u'open sesame' - - a = http.parse_authorization_header('Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==') - assert a.type == 'basic' - assert a.username == u'русскиЁ' - assert a.password == u'Буквы' - - a = http.parse_authorization_header('Basic 5pmu6YCa6K+dOuS4reaWhw==') - assert a.type == 'basic' - assert a.username == u'普通话' - assert a.password == u'中文' - - a = http.parse_authorization_header('''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - qop=auth, - nc=00000001, - cnonce="0a4f113b", - response="6629fae49393a05397450978507c4ef1", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''') - assert a.type == 'digest' - assert a.username == 'Mufasa' - assert a.realm == 'testrealm@host.invalid' - assert a.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093' - assert a.uri == '/dir/index.html' - assert a.qop == 'auth' - assert a.nc == '00000001' - assert a.cnonce == '0a4f113b' - assert a.response == '6629fae49393a05397450978507c4ef1' - assert a.opaque == '5ccc069c403ebaf9f0171e9517f40e41' - - a = http.parse_authorization_header('''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - response="e257afa1414a3340d93d30955171dd0e", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''') - assert a.type == 'digest' - assert a.username == 'Mufasa' - assert a.realm == 'testrealm@host.invalid' - assert a.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093' - assert a.uri == '/dir/index.html' - assert a.response == 'e257afa1414a3340d93d30955171dd0e' - assert a.opaque == '5ccc069c403ebaf9f0171e9517f40e41' - - assert http.parse_authorization_header('') is None + a = http.parse_authorization_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + assert a.type == "basic" + assert a.username == u"Aladdin" + assert a.password == u"open sesame" + + a = http.parse_authorization_header( + "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==" + ) + assert a.type == "basic" + assert a.username == u"русскиЁ" + assert a.password == u"Буквы" + + a = http.parse_authorization_header("Basic 5pmu6YCa6K+dOuS4reaWhw==") + assert a.type == "basic" + assert a.username == u"普通话" + assert a.password == u"中文" + + a = http.parse_authorization_header( + '''Digest username="Mufasa", + realm="testrealm@host.invalid", + nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", + uri="/dir/index.html", + qop=auth, + nc=00000001, + cnonce="0a4f113b", + response="6629fae49393a05397450978507c4ef1", + opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + ) + assert a.type == "digest" + assert a.username == "Mufasa" + assert a.realm == "testrealm@host.invalid" + assert a.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093" + assert a.uri == "/dir/index.html" + assert a.qop == "auth" + assert a.nc == "00000001" + assert a.cnonce == "0a4f113b" + assert a.response == "6629fae49393a05397450978507c4ef1" + assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41" + + a = http.parse_authorization_header( + '''Digest username="Mufasa", + realm="testrealm@host.invalid", + nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", + uri="/dir/index.html", + response="e257afa1414a3340d93d30955171dd0e", + opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + ) + assert a.type == "digest" + assert a.username == "Mufasa" + assert a.realm == "testrealm@host.invalid" + assert a.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093" + assert a.uri == "/dir/index.html" + assert a.response == "e257afa1414a3340d93d30955171dd0e" + assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41" + + assert http.parse_authorization_header("") is None assert http.parse_authorization_header(None) is None - assert http.parse_authorization_header('foo') is None + assert http.parse_authorization_header("foo") is None def test_www_authenticate_header(self): wa = http.parse_www_authenticate_header('Basic realm="WallyWorld"') - assert wa.type == 'basic' - assert wa.realm == 'WallyWorld' - wa.realm = 'Foo Bar' + assert wa.type == "basic" + assert wa.realm == "WallyWorld" + wa.realm = "Foo Bar" assert wa.to_header() == 'Basic realm="Foo Bar"' - wa = http.parse_www_authenticate_header('''Digest - realm="testrealm@host.com", - qop="auth,auth-int", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''') - assert wa.type == 'digest' - assert wa.realm == 'testrealm@host.com' - assert 'auth' in wa.qop - assert 'auth-int' in wa.qop - assert wa.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093' - assert wa.opaque == '5ccc069c403ebaf9f0171e9517f40e41' + wa = http.parse_www_authenticate_header( + '''Digest + realm="testrealm@host.com", + qop="auth,auth-int", + nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", + opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + ) + assert wa.type == "digest" + assert wa.realm == "testrealm@host.com" + assert "auth" in wa.qop + assert "auth-int" in wa.qop + assert wa.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093" + assert wa.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - wa = http.parse_www_authenticate_header('broken') - assert wa.type == 'broken' + wa = http.parse_www_authenticate_header("broken") + assert wa.type == "broken" - assert not http.parse_www_authenticate_header('').type - assert not http.parse_www_authenticate_header('') + assert not http.parse_www_authenticate_header("").type + assert not http.parse_www_authenticate_header("") def test_etags(self): - assert http.quote_etag('foo') == '"foo"' - assert http.quote_etag('foo', True) == 'W/"foo"' - assert http.unquote_etag('"foo"') == ('foo', False) - assert http.unquote_etag('W/"foo"') == ('foo', True) + assert http.quote_etag("foo") == '"foo"' + assert http.quote_etag("foo", True) == 'W/"foo"' + assert http.unquote_etag('"foo"') == ("foo", False) + assert http.unquote_etag('W/"foo"') == ("foo", True) es = http.parse_etags('"foo", "bar", W/"baz", blar') - assert sorted(es) == ['bar', 'blar', 'foo'] - assert 'foo' in es - assert 'baz' not in es - assert es.contains_weak('baz') - assert 'blar' in es + assert sorted(es) == ["bar", "blar", "foo"] + assert "foo" in es + assert "baz" not in es + assert es.contains_weak("baz") + assert "blar" in es assert es.contains_raw('W/"baz"') assert es.contains_raw('"foo"') - assert sorted(es.to_header().split(', ')) == ['"bar"', '"blar"', '"foo"', 'W/"baz"'] + assert sorted(es.to_header().split(", ")) == [ + '"bar"', + '"blar"', + '"foo"', + 'W/"baz"', + ] def test_etags_nonzero(self): etags = http.parse_etags('W/"foo"') @@ -208,392 +231,433 @@ class TestHTTPUtility(object): assert etags.contains_raw('W/"foo"') def test_parse_date(self): - assert http.parse_date('Sun, 06 Nov 1994 08:49:37 GMT ') == datetime( - 1994, 11, 6, 8, 49, 37) - assert http.parse_date('Sunday, 06-Nov-94 08:49:37 GMT') == datetime(1994, 11, 6, 8, 49, 37) - assert http.parse_date(' Sun Nov 6 08:49:37 1994') == datetime(1994, 11, 6, 8, 49, 37) - assert http.parse_date('foo') is None + assert http.parse_date("Sun, 06 Nov 1994 08:49:37 GMT ") == datetime( + 1994, 11, 6, 8, 49, 37 + ) + assert http.parse_date("Sunday, 06-Nov-94 08:49:37 GMT") == datetime( + 1994, 11, 6, 8, 49, 37 + ) + assert http.parse_date(" Sun Nov 6 08:49:37 1994") == datetime( + 1994, 11, 6, 8, 49, 37 + ) + assert http.parse_date("foo") is None def test_parse_date_overflows(self): - assert http.parse_date(' Sun 02 Feb 1343 08:49:37 GMT') == datetime(1343, 2, 2, 8, 49, 37) - assert http.parse_date('Thu, 01 Jan 1970 00:00:00 GMT') == datetime(1970, 1, 1, 0, 0) - assert http.parse_date('Thu, 33 Jan 1970 00:00:00 GMT') is None + assert http.parse_date(" Sun 02 Feb 1343 08:49:37 GMT") == datetime( + 1343, 2, 2, 8, 49, 37 + ) + assert http.parse_date("Thu, 01 Jan 1970 00:00:00 GMT") == datetime( + 1970, 1, 1, 0, 0 + ) + assert http.parse_date("Thu, 33 Jan 1970 00:00:00 GMT") is None def test_remove_entity_headers(self): now = http.http_date() - headers1 = [('Date', now), ('Content-Type', 'text/html'), ('Content-Length', '0')] + headers1 = [ + ("Date", now), + ("Content-Type", "text/html"), + ("Content-Length", "0"), + ] headers2 = datastructures.Headers(headers1) http.remove_entity_headers(headers1) - assert headers1 == [('Date', now)] + assert headers1 == [("Date", now)] http.remove_entity_headers(headers2) - assert headers2 == datastructures.Headers([(u'Date', now)]) + assert headers2 == datastructures.Headers([(u"Date", now)]) def test_remove_hop_by_hop_headers(self): - headers1 = [('Connection', 'closed'), ('Foo', 'bar'), - ('Keep-Alive', 'wtf')] + headers1 = [("Connection", "closed"), ("Foo", "bar"), ("Keep-Alive", "wtf")] headers2 = datastructures.Headers(headers1) http.remove_hop_by_hop_headers(headers1) - assert headers1 == [('Foo', 'bar')] + assert headers1 == [("Foo", "bar")] http.remove_hop_by_hop_headers(headers2) - assert headers2 == datastructures.Headers([('Foo', 'bar')]) + assert headers2 == datastructures.Headers([("Foo", "bar")]) def test_parse_options_header(self): - assert http.parse_options_header(None) == \ - ('', {}) - assert http.parse_options_header("") == \ - ('', {}) - assert http.parse_options_header(r'something; foo="other\"thing"') == \ - ('something', {'foo': 'other"thing'}) - assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == \ - ('something', {'foo': 'other"thing', 'meh': '42'}) - assert http.parse_options_header(r'something; foo="other\"thing"; meh=42; bleh') == \ - ('something', {'foo': 'other"thing', 'meh': '42', 'bleh': None}) - assert http.parse_options_header('something; foo="other;thing"; meh=42; bleh') == \ - ('something', {'foo': 'other;thing', 'meh': '42', 'bleh': None}) - assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == \ - ('something', {'foo': 'otherthing', 'meh': None, 'bleh': None}) + assert http.parse_options_header(None) == ("", {}) + assert http.parse_options_header("") == ("", {}) + assert http.parse_options_header(r'something; foo="other\"thing"') == ( + "something", + {"foo": 'other"thing'}, + ) + assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == ( + "something", + {"foo": 'other"thing', "meh": "42"}, + ) + assert http.parse_options_header( + r'something; foo="other\"thing"; meh=42; bleh' + ) == ("something", {"foo": 'other"thing', "meh": "42", "bleh": None}) + assert http.parse_options_header( + 'something; foo="other;thing"; meh=42; bleh' + ) == ("something", {"foo": "other;thing", "meh": "42", "bleh": None}) + assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == ( + "something", + {"foo": "otherthing", "meh": None, "bleh": None}, + ) # Issue #404 - assert http.parse_options_header('multipart/form-data; name="foo bar"; ' - 'filename="bar foo"') == \ - ('multipart/form-data', {'name': 'foo bar', 'filename': 'bar foo'}) + assert http.parse_options_header( + 'multipart/form-data; name="foo bar"; ' 'filename="bar foo"' + ) == ("multipart/form-data", {"name": "foo bar", "filename": "bar foo"}) # Examples from RFC - assert http.parse_options_header('audio/*; q=0.2, audio/basic') == \ - ('audio/*', {'q': '0.2'}) - assert http.parse_options_header('audio/*; q=0.2, audio/basic', multiple=True) == \ - ('audio/*', {'q': '0.2'}, "audio/basic", {}) + assert http.parse_options_header("audio/*; q=0.2, audio/basic") == ( + "audio/*", + {"q": "0.2"}, + ) + assert http.parse_options_header( + "audio/*; q=0.2, audio/basic", multiple=True + ) == ("audio/*", {"q": "0.2"}, "audio/basic", {}) + assert http.parse_options_header( + "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c", + multiple=True, + ) == ( + "text/plain", + {"q": "0.5"}, + "text/html", + {}, + "text/x-dvi", + {"q": "0.8"}, + "text/x-c", + {}, + ) assert http.parse_options_header( - 'text/plain; q=0.5, text/html\n ' - 'text/x-dvi; q=0.8, text/x-c', - multiple=True) == \ - ('text/plain', {'q': '0.5'}, "text/html", {}, - "text/x-dvi", {'q': '0.8'}, "text/x-c", {}) - assert http.parse_options_header('text/plain; q=0.5, text/html\n' - ' ' - 'text/x-dvi; q=0.8, text/x-c') == \ - ('text/plain', {'q': '0.5'}) + "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c" + ) == ("text/plain", {"q": "0.5"}) # Issue #932 assert http.parse_options_header( - 'form-data; ' - 'name="a_file"; ' - 'filename*=UTF-8\'\'' - '"%c2%a3%20and%20%e2%82%ac%20rates"') == \ - ('form-data', {'name': 'a_file', - 'filename': u'\xa3 and \u20ac rates'}) + "form-data; name=\"a_file\"; filename*=UTF-8''" + '"%c2%a3%20and%20%e2%82%ac%20rates"' + ) == ("form-data", {"name": "a_file", "filename": u"\xa3 and \u20ac rates"}) assert http.parse_options_header( - 'form-data; ' - 'name*=UTF-8\'\'"%C5%AAn%C4%ADc%C5%8Dde%CC%BD"; ' - 'filename="some_file.txt"') == \ - ('form-data', {'name': u'\u016an\u012dc\u014dde\u033d', - 'filename': 'some_file.txt'}) + "form-data; name*=UTF-8''\"%C5%AAn%C4%ADc%C5%8Dde%CC%BD\"; " + 'filename="some_file.txt"' + ) == ( + "form-data", + {"name": u"\u016an\u012dc\u014dde\u033d", "filename": "some_file.txt"}, + ) def test_parse_options_header_value_with_quotes(self): assert http.parse_options_header( 'form-data; name="file"; filename="t\'es\'t.txt"' - ) == ('form-data', {'name': 'file', 'filename': "t'es't.txt"}) + ) == ("form-data", {"name": "file", "filename": "t'es't.txt"}) assert http.parse_options_header( - 'form-data; name="file"; filename*=UTF-8\'\'"\'🐍\'.txt"' - ) == ('form-data', {'name': 'file', 'filename': u"'🐍'.txt"}) + "form-data; name=\"file\"; filename*=UTF-8''\"'🐍'.txt\"" + ) == ("form-data", {"name": "file", "filename": u"'🐍'.txt"}) def test_parse_options_header_broken_values(self): # Issue #995 - assert http.parse_options_header(' ') == ('', {}) - assert http.parse_options_header(' , ') == ('', {}) - assert http.parse_options_header(' ; ') == ('', {}) - assert http.parse_options_header(' ,; ') == ('', {}) - assert http.parse_options_header(' , a ') == ('', {}) - assert http.parse_options_header(' ; a ') == ('', {}) + assert http.parse_options_header(" ") == ("", {}) + assert http.parse_options_header(" , ") == ("", {}) + assert http.parse_options_header(" ; ") == ("", {}) + assert http.parse_options_header(" ,; ") == ("", {}) + assert http.parse_options_header(" , a ") == ("", {}) + assert http.parse_options_header(" ; a ") == ("", {}) def test_dump_options_header(self): - assert http.dump_options_header('foo', {'bar': 42}) == \ - 'foo; bar=42' - assert http.dump_options_header('foo', {'bar': 42, 'fizz': None}) in \ - ('foo; bar=42; fizz', 'foo; fizz; bar=42') + assert http.dump_options_header("foo", {"bar": 42}) == "foo; bar=42" + assert http.dump_options_header("foo", {"bar": 42, "fizz": None}) in ( + "foo; bar=42; fizz", + "foo; fizz; bar=42", + ) def test_dump_header(self): - assert http.dump_header([1, 2, 3]) == '1, 2, 3' + assert http.dump_header([1, 2, 3]) == "1, 2, 3" assert http.dump_header([1, 2, 3], allow_token=False) == '"1", "2", "3"' - assert http.dump_header({'foo': 'bar'}, allow_token=False) == 'foo="bar"' - assert http.dump_header({'foo': 'bar'}) == 'foo=bar' + assert http.dump_header({"foo": "bar"}, allow_token=False) == 'foo="bar"' + assert http.dump_header({"foo": "bar"}) == "foo=bar" def test_is_resource_modified(self): env = create_environ() # ignore POST - env['REQUEST_METHOD'] = 'POST' - assert not http.is_resource_modified(env, etag='testing') - env['REQUEST_METHOD'] = 'GET' + env["REQUEST_METHOD"] = "POST" + assert not http.is_resource_modified(env, etag="testing") + env["REQUEST_METHOD"] = "GET" # etagify from data - pytest.raises(TypeError, http.is_resource_modified, env, - data='42', etag='23') - env['HTTP_IF_NONE_MATCH'] = http.generate_etag(b'awesome') - assert not http.is_resource_modified(env, data=b'awesome') + pytest.raises(TypeError, http.is_resource_modified, env, data="42", etag="23") + env["HTTP_IF_NONE_MATCH"] = http.generate_etag(b"awesome") + assert not http.is_resource_modified(env, data=b"awesome") - env['HTTP_IF_MODIFIED_SINCE'] = http.http_date(datetime(2008, 1, 1, 12, 30)) - assert not http.is_resource_modified(env, - last_modified=datetime(2008, 1, 1, 12, 00)) - assert http.is_resource_modified(env, - last_modified=datetime(2008, 1, 1, 13, 00)) + env["HTTP_IF_MODIFIED_SINCE"] = http.http_date(datetime(2008, 1, 1, 12, 30)) + assert not http.is_resource_modified( + env, last_modified=datetime(2008, 1, 1, 12, 00) + ) + assert http.is_resource_modified( + env, last_modified=datetime(2008, 1, 1, 13, 00) + ) def test_is_resource_modified_for_range_requests(self): env = create_environ() - env['HTTP_IF_MODIFIED_SINCE'] = http.http_date(datetime(2008, 1, 1, 12, 30)) - env['HTTP_IF_RANGE'] = http.generate_etag(b'awesome_if_range') + env["HTTP_IF_MODIFIED_SINCE"] = http.http_date(datetime(2008, 1, 1, 12, 30)) + env["HTTP_IF_RANGE"] = http.generate_etag(b"awesome_if_range") # Range header not present, so If-Range should be ignored - assert not http.is_resource_modified(env, data=b'not_the_same', - ignore_if_range=False, - last_modified=datetime(2008, 1, 1, 12, 30)) - - env['HTTP_RANGE'] = '' - assert not http.is_resource_modified(env, data=b'awesome_if_range', - ignore_if_range=False) - assert http.is_resource_modified(env, data=b'not_the_same', - ignore_if_range=False) - - env['HTTP_IF_RANGE'] = http.http_date(datetime(2008, 1, 1, 13, 30)) - assert http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 14, 00), - ignore_if_range=False) - assert not http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 13, 30), - ignore_if_range=False) - assert http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 13, 30), - ignore_if_range=True) + assert not http.is_resource_modified( + env, + data=b"not_the_same", + ignore_if_range=False, + last_modified=datetime(2008, 1, 1, 12, 30), + ) + + env["HTTP_RANGE"] = "" + assert not http.is_resource_modified( + env, data=b"awesome_if_range", ignore_if_range=False + ) + assert http.is_resource_modified( + env, data=b"not_the_same", ignore_if_range=False + ) + + env["HTTP_IF_RANGE"] = http.http_date(datetime(2008, 1, 1, 13, 30)) + assert http.is_resource_modified( + env, last_modified=datetime(2008, 1, 1, 14, 00), ignore_if_range=False + ) + assert not http.is_resource_modified( + env, last_modified=datetime(2008, 1, 1, 13, 30), ignore_if_range=False + ) + assert http.is_resource_modified( + env, last_modified=datetime(2008, 1, 1, 13, 30), ignore_if_range=True + ) def test_date_formatting(self): - assert http.cookie_date(0) == 'Thu, 01-Jan-1970 00:00:00 GMT' - assert http.cookie_date(datetime(1970, 1, 1)) == 'Thu, 01-Jan-1970 00:00:00 GMT' - assert http.http_date(0) == 'Thu, 01 Jan 1970 00:00:00 GMT' - assert http.http_date(datetime(1970, 1, 1)) == 'Thu, 01 Jan 1970 00:00:00 GMT' + assert http.cookie_date(0) == "Thu, 01-Jan-1970 00:00:00 GMT" + assert http.cookie_date(datetime(1970, 1, 1)) == "Thu, 01-Jan-1970 00:00:00 GMT" + assert http.http_date(0) == "Thu, 01 Jan 1970 00:00:00 GMT" + assert http.http_date(datetime(1970, 1, 1)) == "Thu, 01 Jan 1970 00:00:00 GMT" def test_cookies(self): strict_eq( - dict(http.parse_cookie('dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cd' - 'c762809248d4beed; a=42; b="\\\";"')), + dict( + http.parse_cookie( + "dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cd" + 'c762809248d4beed; a=42; b="\\";"' + ) + ), { - 'CP': u'null*', - 'PHPSESSID': u'0a539d42abc001cdc762809248d4beed', - 'a': u'42', - 'dismiss-top': u'6', - 'b': u'\";' - } - ) - rv = http.dump_cookie('foo', 'bar baz blub', 360, httponly=True, - sync_expires=False) + "CP": u"null*", + "PHPSESSID": u"0a539d42abc001cdc762809248d4beed", + "a": u"42", + "dismiss-top": u"6", + "b": u'";', + }, + ) + rv = http.dump_cookie( + "foo", "bar baz blub", 360, httponly=True, sync_expires=False + ) assert type(rv) is str - assert set(rv.split('; ')) == set(['HttpOnly', 'Max-Age=360', - 'Path=/', 'foo="bar baz blub"']) + assert set(rv.split("; ")) == { + "HttpOnly", + "Max-Age=360", + "Path=/", + 'foo="bar baz blub"', + } - strict_eq(dict(http.parse_cookie('fo234{=bar; blub=Blah')), - {'fo234{': u'bar', 'blub': u'Blah'}) + strict_eq( + dict(http.parse_cookie("fo234{=bar; blub=Blah")), + {"fo234{": u"bar", "blub": u"Blah"}, + ) - strict_eq(http.dump_cookie('key', 'xxx/'), 'key=xxx/; Path=/') - strict_eq(http.dump_cookie('key', 'xxx='), 'key=xxx=; Path=/') + strict_eq(http.dump_cookie("key", "xxx/"), "key=xxx/; Path=/") + strict_eq(http.dump_cookie("key", "xxx="), "key=xxx=; Path=/") def test_bad_cookies(self): strict_eq( - dict(http.parse_cookie( - 'first=IamTheFirst ; a=1; oops ; a=2 ;second = andMeTwo;' - )), - { - 'first': u'IamTheFirst', - 'a': u'2', - 'oops': u'', - 'second': u'andMeTwo', - } + dict( + http.parse_cookie( + "first=IamTheFirst ; a=1; oops ; a=2 ;second = andMeTwo;" + ) + ), + {"first": u"IamTheFirst", "a": u"2", "oops": u"", "second": u"andMeTwo"}, ) def test_empty_keys_are_ignored(self): strict_eq( - dict(http.parse_cookie( - 'first=IamTheFirst ; a=1; a=2 ;second=andMeTwo; ; ' - )), - { - 'first': u'IamTheFirst', - 'a': u'2', - 'second': u'andMeTwo' - } + dict( + http.parse_cookie("first=IamTheFirst ; a=1; a=2 ;second=andMeTwo; ; ") + ), + {"first": u"IamTheFirst", "a": u"2", "second": u"andMeTwo"}, ) def test_cookie_quoting(self): val = http.dump_cookie("foo", "?foo") strict_eq(val, 'foo="?foo"; Path=/') - strict_eq(dict(http.parse_cookie(val)), {'foo': u'?foo'}) + strict_eq(dict(http.parse_cookie(val)), {"foo": u"?foo"}) - strict_eq(dict(http.parse_cookie(r'foo="foo\054bar"')), - {'foo': u'foo,bar'}) + strict_eq(dict(http.parse_cookie(r'foo="foo\054bar"')), {"foo": u"foo,bar"}) def test_cookie_domain_resolving(self): - val = http.dump_cookie('foo', 'bar', domain=u'\N{SNOWMAN}.com') - strict_eq(val, 'foo=bar; Domain=xn--n3h.com; Path=/') + val = http.dump_cookie("foo", "bar", domain=u"\N{SNOWMAN}.com") + strict_eq(val, "foo=bar; Domain=xn--n3h.com; Path=/") def test_cookie_unicode_dumping(self): - val = http.dump_cookie('foo', u'\N{SNOWMAN}') + val = http.dump_cookie("foo", u"\N{SNOWMAN}") h = datastructures.Headers() - h.add('Set-Cookie', val) - assert h['Set-Cookie'] == 'foo="\\342\\230\\203"; Path=/' + h.add("Set-Cookie", val) + assert h["Set-Cookie"] == 'foo="\\342\\230\\203"; Path=/' - cookies = http.parse_cookie(h['Set-Cookie']) - assert cookies['foo'] == u'\N{SNOWMAN}' + cookies = http.parse_cookie(h["Set-Cookie"]) + assert cookies["foo"] == u"\N{SNOWMAN}" def test_cookie_unicode_keys(self): # Yes, this is technically against the spec but happens - val = http.dump_cookie(u'fö', u'fö') - assert val == wsgi_encoding_dance(u'fö="f\\303\\266"; Path=/', 'utf-8') + val = http.dump_cookie(u"fö", u"fö") + assert val == wsgi_encoding_dance(u'fö="f\\303\\266"; Path=/', "utf-8") cookies = http.parse_cookie(val) - assert cookies[u'fö'] == u'fö' + assert cookies[u"fö"] == u"fö" def test_cookie_unicode_parsing(self): # This is actually a correct test. This is what is being submitted # by firefox if you set an unicode cookie and we get the cookie sent # in on Python 3 under PEP 3333. - cookies = http.parse_cookie(u'fö=fö') - assert cookies[u'fö'] == u'fö' + cookies = http.parse_cookie(u"fö=fö") + assert cookies[u"fö"] == u"fö" def test_cookie_domain_encoding(self): - val = http.dump_cookie('foo', 'bar', domain=u'\N{SNOWMAN}.com') - strict_eq(val, 'foo=bar; Domain=xn--n3h.com; Path=/') + val = http.dump_cookie("foo", "bar", domain=u"\N{SNOWMAN}.com") + strict_eq(val, "foo=bar; Domain=xn--n3h.com; Path=/") - val = http.dump_cookie('foo', 'bar', domain=u'.\N{SNOWMAN}.com') - strict_eq(val, 'foo=bar; Domain=.xn--n3h.com; Path=/') + val = http.dump_cookie("foo", "bar", domain=u".\N{SNOWMAN}.com") + strict_eq(val, "foo=bar; Domain=.xn--n3h.com; Path=/") - val = http.dump_cookie('foo', 'bar', domain=u'.foo.com') - strict_eq(val, 'foo=bar; Domain=.foo.com; Path=/') + val = http.dump_cookie("foo", "bar", domain=u".foo.com") + strict_eq(val, "foo=bar; Domain=.foo.com; Path=/") def test_cookie_maxsize(self, recwarn): - val = http.dump_cookie('foo', 'bar' * 1360 + 'b') + val = http.dump_cookie("foo", "bar" * 1360 + "b") assert len(recwarn) == 0 assert len(val) == 4093 - http.dump_cookie('foo', 'bar' * 1360 + 'ba') + http.dump_cookie("foo", "bar" * 1360 + "ba") assert len(recwarn) == 1 w = recwarn.pop() - assert 'cookie is too large' in str(w.message) + assert "cookie is too large" in str(w.message) - http.dump_cookie('foo', b'w' * 502, max_size=512) + http.dump_cookie("foo", b"w" * 502, max_size=512) assert len(recwarn) == 1 w = recwarn.pop() - assert 'the limit is 512 bytes' in str(w.message) - - @pytest.mark.parametrize('input, expected', [ - ('strict', 'foo=bar; Path=/; SameSite=Strict'), - ('lax', 'foo=bar; Path=/; SameSite=Lax'), - (None, 'foo=bar; Path=/'), - ]) + assert "the limit is 512 bytes" in str(w.message) + + @pytest.mark.parametrize( + "input, expected", + [ + ("strict", "foo=bar; Path=/; SameSite=Strict"), + ("lax", "foo=bar; Path=/; SameSite=Lax"), + (None, "foo=bar; Path=/"), + ], + ) def test_cookie_samesite_attribute(self, input, expected): - val = http.dump_cookie('foo', 'bar', samesite=input) + val = http.dump_cookie("foo", "bar", samesite=input) strict_eq(val, expected) class TestRange(object): - def test_if_range_parsing(self): rv = http.parse_if_range_header('"Test"') - assert rv.etag == 'Test' + assert rv.etag == "Test" assert rv.date is None assert rv.to_header() == '"Test"' # weak information is dropped rv = http.parse_if_range_header('W/"Test"') - assert rv.etag == 'Test' + assert rv.etag == "Test" assert rv.date is None assert rv.to_header() == '"Test"' # broken etags are supported too - rv = http.parse_if_range_header('bullshit') - assert rv.etag == 'bullshit' + rv = http.parse_if_range_header("bullshit") + assert rv.etag == "bullshit" assert rv.date is None assert rv.to_header() == '"bullshit"' - rv = http.parse_if_range_header('Thu, 01 Jan 1970 00:00:00 GMT') + rv = http.parse_if_range_header("Thu, 01 Jan 1970 00:00:00 GMT") assert rv.etag is None assert rv.date == datetime(1970, 1, 1) - assert rv.to_header() == 'Thu, 01 Jan 1970 00:00:00 GMT' + assert rv.to_header() == "Thu, 01 Jan 1970 00:00:00 GMT" - for x in '', None: + for x in "", None: rv = http.parse_if_range_header(x) assert rv.etag is None assert rv.date is None - assert rv.to_header() == '' + assert rv.to_header() == "" def test_range_parsing(self): - rv = http.parse_range_header('bytes=52') + rv = http.parse_range_header("bytes=52") assert rv is None - rv = http.parse_range_header('bytes=52-') - assert rv.units == 'bytes' + rv = http.parse_range_header("bytes=52-") + assert rv.units == "bytes" assert rv.ranges == [(52, None)] - assert rv.to_header() == 'bytes=52-' + assert rv.to_header() == "bytes=52-" - rv = http.parse_range_header('bytes=52-99') - assert rv.units == 'bytes' + rv = http.parse_range_header("bytes=52-99") + assert rv.units == "bytes" assert rv.ranges == [(52, 100)] - assert rv.to_header() == 'bytes=52-99' + assert rv.to_header() == "bytes=52-99" - rv = http.parse_range_header('bytes=52-99,-1000') - assert rv.units == 'bytes' + rv = http.parse_range_header("bytes=52-99,-1000") + assert rv.units == "bytes" assert rv.ranges == [(52, 100), (-1000, None)] - assert rv.to_header() == 'bytes=52-99,-1000' + assert rv.to_header() == "bytes=52-99,-1000" - rv = http.parse_range_header('bytes = 1 - 100') - assert rv.units == 'bytes' + rv = http.parse_range_header("bytes = 1 - 100") + assert rv.units == "bytes" assert rv.ranges == [(1, 101)] - assert rv.to_header() == 'bytes=1-100' + assert rv.to_header() == "bytes=1-100" - rv = http.parse_range_header('AWesomes=0-999') - assert rv.units == 'awesomes' + rv = http.parse_range_header("AWesomes=0-999") + assert rv.units == "awesomes" assert rv.ranges == [(0, 1000)] - assert rv.to_header() == 'awesomes=0-999' + assert rv.to_header() == "awesomes=0-999" - rv = http.parse_range_header('bytes=-') + rv = http.parse_range_header("bytes=-") assert rv is None - rv = http.parse_range_header('bytes=bad') + rv = http.parse_range_header("bytes=bad") assert rv is None - rv = http.parse_range_header('bytes=bad-1') + rv = http.parse_range_header("bytes=bad-1") assert rv is None - rv = http.parse_range_header('bytes=-bad') + rv = http.parse_range_header("bytes=-bad") assert rv is None - rv = http.parse_range_header('bytes=52-99, bad') + rv = http.parse_range_header("bytes=52-99, bad") assert rv is None def test_content_range_parsing(self): - rv = http.parse_content_range_header('bytes 0-98/*') - assert rv.units == 'bytes' + rv = http.parse_content_range_header("bytes 0-98/*") + assert rv.units == "bytes" assert rv.start == 0 assert rv.stop == 99 assert rv.length is None - assert rv.to_header() == 'bytes 0-98/*' + assert rv.to_header() == "bytes 0-98/*" - rv = http.parse_content_range_header('bytes 0-98/*asdfsa') + rv = http.parse_content_range_header("bytes 0-98/*asdfsa") assert rv is None - rv = http.parse_content_range_header('bytes 0-99/100') - assert rv.to_header() == 'bytes 0-99/100' + rv = http.parse_content_range_header("bytes 0-99/100") + assert rv.to_header() == "bytes 0-99/100" rv.start = None rv.stop = None - assert rv.units == 'bytes' - assert rv.to_header() == 'bytes */100' + assert rv.units == "bytes" + assert rv.to_header() == "bytes */100" - rv = http.parse_content_range_header('bytes */100') + rv = http.parse_content_range_header("bytes */100") assert rv.start is None assert rv.stop is None assert rv.length == 100 - assert rv.units == 'bytes' + assert rv.units == "bytes" class TestRegression(object): - def test_best_match_works(self): # was a bug in 0.6 - rv = http.parse_accept_header('foo=,application/xml,application/xhtml+xml,' - 'text/html;q=0.9,text/plain;q=0.8,' - 'image/png,*/*;q=0.5', - datastructures.MIMEAccept).best_match(['foo/bar']) - assert rv == 'foo/bar' + rv = http.parse_accept_header( + "foo=,application/xml,application/xhtml+xml," + "text/html;q=0.9,text/plain;q=0.8," + "image/png,*/*;q=0.5", + datastructures.MIMEAccept, + ).best_match(["foo/bar"]) + assert rv == "foo/bar" diff --git a/tests/test_internal.py b/tests/test_internal.py index e272b58e..ca2f92cb 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -8,15 +8,16 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pytest - from datetime import datetime -from warnings import filterwarnings, resetwarnings +from warnings import filterwarnings +from warnings import resetwarnings -from werkzeug.wrappers import Request, Response +import pytest from werkzeug import _internal as internal from werkzeug.test import create_environ +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def test_date_to_unix(): @@ -28,44 +29,44 @@ def test_date_to_unix(): def test_easteregg(): - req = Request.from_values('/?macgybarchakku') + req = Request.from_values("/?macgybarchakku") resp = Response.force_type(internal._easteregg(None), req) - assert b'About Werkzeug' in resp.get_data() - assert b'the Swiss Army knife of Python web development' in resp.get_data() + assert b"About Werkzeug" in resp.get_data() + assert b"the Swiss Army knife of Python web development" in resp.get_data() def test_wrapper_internals(): - req = Request.from_values(data={'foo': 'bar'}, method='POST') + req = Request.from_values(data={"foo": "bar"}, method="POST") req._load_form_data() - assert req.form.to_dict() == {'foo': 'bar'} + assert req.form.to_dict() == {"foo": "bar"} # second call does not break req._load_form_data() - assert req.form.to_dict() == {'foo': 'bar'} + assert req.form.to_dict() == {"foo": "bar"} # check reprs assert repr(req) == "<Request 'http://localhost/' [POST]>" resp = Response() - assert repr(resp) == '<Response 0 bytes [200 OK]>' - resp.set_data('Hello World!') - assert repr(resp) == '<Response 12 bytes [200 OK]>' - resp.response = iter(['Test']) - assert repr(resp) == '<Response streamed [200 OK]>' + assert repr(resp) == "<Response 0 bytes [200 OK]>" + resp.set_data("Hello World!") + assert repr(resp) == "<Response 12 bytes [200 OK]>" + resp.response = iter(["Test"]) + assert repr(resp) == "<Response streamed [200 OK]>" # unicode data does not set content length - response = Response([u'Hällo Wörld']) + response = Response([u"Hällo Wörld"]) headers = response.get_wsgi_headers(create_environ()) - assert u'Content-Length' not in headers + assert u"Content-Length" not in headers - response = Response([u'Hällo Wörld'.encode('utf-8')]) + response = Response([u"Hällo Wörld".encode("utf-8")]) headers = response.get_wsgi_headers(create_environ()) - assert u'Content-Length' in headers + assert u"Content-Length" in headers # check for internal warnings - filterwarnings('error', category=Warning) + filterwarnings("error", category=Warning) response = Response() environ = create_environ() - response.response = 'What the...?' + response.response = "What the...?" pytest.raises(Warning, lambda: list(response.iter_encoded())) pytest.raises(Warning, lambda: list(response.get_app_iter(environ))) response.direct_passthrough = True diff --git a/tests/test_local.py b/tests/test_local.py index 26f63eba..9bbd0fe8 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -8,13 +8,13 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pytest - -import time import copy +import time from functools import partial from threading import Thread +import pytest + from werkzeug import local @@ -28,8 +28,8 @@ def test_basic_local(): ns.foo = idx time.sleep(0.02) values.append(ns.foo) - threads = [Thread(target=value_setter, args=(x,)) - for x in [1, 2, 3]] + + threads = [Thread(target=value_setter, args=(x,)) for x in [1, 2, 3]] for thread in threads: thread.start() time.sleep(0.2) @@ -37,6 +37,7 @@ def test_basic_local(): def delfoo(): del ns.foo + delfoo() pytest.raises(AttributeError, lambda: ns.foo) pytest.raises(AttributeError, delfoo) @@ -48,7 +49,7 @@ def test_local_release(): ns = local.Local() ns.foo = 42 local.release_local(ns) - assert not hasattr(ns, 'foo') + assert not hasattr(ns, "foo") ls = local.LocalStack() ls.push(42) @@ -122,7 +123,7 @@ def test_local_stack(): assert proxy == (1, 2) ls.pop() ls.pop() - assert repr(proxy) == '<LocalProxy unbound>' + assert repr(proxy) == "<LocalProxy unbound>" assert ident not in ls._local.__storage__ @@ -144,18 +145,18 @@ def test_custom_idents(): local.LocalManager([ns, stack], ident_func=lambda: ident) ns.foo = 42 - stack.push({'foo': 42}) + stack.push({"foo": 42}) ident = 1 ns.foo = 23 - stack.push({'foo': 23}) + stack.push({"foo": 23}) ident = 0 assert ns.foo == 42 - assert stack.top['foo'] == 42 + assert stack.top["foo"] == 42 stack.pop() assert stack.top is None ident = 1 assert ns.foo == 23 - assert stack.top['foo'] == 23 + assert stack.top["foo"] == 23 stack.pop() assert stack.top is None @@ -169,6 +170,7 @@ def test_deepcopy_on_proxy(): def __deepcopy__(self, memo): return self + f = Foo() p = local.LocalProxy(lambda: f) assert p.attr == 42 @@ -186,7 +188,7 @@ def test_deepcopy_on_proxy(): def test_local_proxy_wrapped_attribute(): class SomeClassWithWrapped(object): - __wrapped__ = 'wrapped' + __wrapped__ = "wrapped" def lookup_func(): return 42 @@ -203,5 +205,5 @@ def test_local_proxy_wrapped_attribute(): ns.foo = SomeClassWithWrapped() ns.bar = 42 - assert ns('foo').__wrapped__ == 'wrapped' - pytest.raises(AttributeError, lambda: ns('bar').__wrapped__) + assert ns("foo").__wrapped__ == "wrapped" + pytest.raises(AttributeError, lambda: ns("bar").__wrapped__) diff --git a/tests/test_routing.py b/tests/test_routing.py index fa604ee0..835ca684 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -8,286 +8,313 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pytest - import uuid -from tests import strict_eq +import pytest +from . import strict_eq from werkzeug import routing as r -from werkzeug.wrappers import Response -from werkzeug.datastructures import ImmutableDict, MultiDict +from werkzeug.datastructures import ImmutableDict +from werkzeug.datastructures import MultiDict from werkzeug.test import create_environ +from werkzeug.wrappers import Response def test_basic_routing(): - map = r.Map([ - r.Rule('/', endpoint='index'), - r.Rule('/foo', endpoint='foo'), - r.Rule('/bar/', endpoint='bar') - ]) - adapter = map.bind('example.org', '/') - assert adapter.match('/') == ('index', {}) - assert adapter.match('/foo') == ('foo', {}) - assert adapter.match('/bar/') == ('bar', {}) - pytest.raises(r.RequestRedirect, lambda: adapter.match('/bar')) - pytest.raises(r.NotFound, lambda: adapter.match('/blub')) - - adapter = map.bind('example.org', '/test') + map = r.Map( + [ + r.Rule("/", endpoint="index"), + r.Rule("/foo", endpoint="foo"), + r.Rule("/bar/", endpoint="bar"), + ] + ) + adapter = map.bind("example.org", "/") + assert adapter.match("/") == ("index", {}) + assert adapter.match("/foo") == ("foo", {}) + assert adapter.match("/bar/") == ("bar", {}) + pytest.raises(r.RequestRedirect, lambda: adapter.match("/bar")) + pytest.raises(r.NotFound, lambda: adapter.match("/blub")) + + adapter = map.bind("example.org", "/test") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match('/bar') - assert excinfo.value.new_url == 'http://example.org/test/bar/' + adapter.match("/bar") + assert excinfo.value.new_url == "http://example.org/test/bar/" - adapter = map.bind('example.org', '/') + adapter = map.bind("example.org", "/") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match('/bar') - assert excinfo.value.new_url == 'http://example.org/bar/' + adapter.match("/bar") + assert excinfo.value.new_url == "http://example.org/bar/" - adapter = map.bind('example.org', '/') + adapter = map.bind("example.org", "/") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match('/bar', query_args={'aha': 'muhaha'}) - assert excinfo.value.new_url == 'http://example.org/bar/?aha=muhaha' + adapter.match("/bar", query_args={"aha": "muhaha"}) + assert excinfo.value.new_url == "http://example.org/bar/?aha=muhaha" - adapter = map.bind('example.org', '/') + adapter = map.bind("example.org", "/") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match('/bar', query_args='aha=muhaha') - assert excinfo.value.new_url == 'http://example.org/bar/?aha=muhaha' + adapter.match("/bar", query_args="aha=muhaha") + assert excinfo.value.new_url == "http://example.org/bar/?aha=muhaha" - adapter = map.bind_to_environ(create_environ('/bar?foo=bar', - 'http://example.org/')) + adapter = map.bind_to_environ(create_environ("/bar?foo=bar", "http://example.org/")) with pytest.raises(r.RequestRedirect) as excinfo: adapter.match() - assert excinfo.value.new_url == 'http://example.org/bar/?foo=bar' + assert excinfo.value.new_url == "http://example.org/bar/?foo=bar" def test_strict_slashes_redirect(): - map = r.Map([ - r.Rule('/bar/', endpoint='get', methods=["GET"]), - r.Rule('/bar', endpoint='post', methods=["POST"]), - r.Rule('/foo/', endpoint='foo', methods=["POST"]), - ]) - adapter = map.bind('example.org', '/') + map = r.Map( + [ + r.Rule("/bar/", endpoint="get", methods=["GET"]), + r.Rule("/bar", endpoint="post", methods=["POST"]), + r.Rule("/foo/", endpoint="foo", methods=["POST"]), + ] + ) + adapter = map.bind("example.org", "/") # Check if the actual routes works - assert adapter.match('/bar/', method='GET') == ('get', {}) - assert adapter.match('/bar', method='POST') == ('post', {}) + assert adapter.match("/bar/", method="GET") == ("get", {}) + assert adapter.match("/bar", method="POST") == ("post", {}) # Check if exceptions are correct - pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET') - pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST') + pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET") + pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST") with pytest.raises(r.RequestRedirect) as error_info: - adapter.match('/foo', method='POST') + adapter.match("/foo", method="POST") assert error_info.value.code == 308 # Check differently defined order - map = r.Map([ - r.Rule('/bar', endpoint='post', methods=["POST"]), - r.Rule('/bar/', endpoint='get', methods=["GET"]), - ]) - adapter = map.bind('example.org', '/') + map = r.Map( + [ + r.Rule("/bar", endpoint="post", methods=["POST"]), + r.Rule("/bar/", endpoint="get", methods=["GET"]), + ] + ) + adapter = map.bind("example.org", "/") # Check if the actual routes works - assert adapter.match('/bar/', method='GET') == ('get', {}) - assert adapter.match('/bar', method='POST') == ('post', {}) + assert adapter.match("/bar/", method="GET") == ("get", {}) + assert adapter.match("/bar", method="POST") == ("post", {}) # Check if exceptions are correct - pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET') - pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST') + pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET") + pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST") # Check what happens when only slash route is defined - map = r.Map([ - r.Rule('/bar/', endpoint='get', methods=["GET"]), - ]) - adapter = map.bind('example.org', '/') + map = r.Map([r.Rule("/bar/", endpoint="get", methods=["GET"])]) + adapter = map.bind("example.org", "/") # Check if the actual routes works - assert adapter.match('/bar/', method='GET') == ('get', {}) + assert adapter.match("/bar/", method="GET") == ("get", {}) # Check if exceptions are correct - pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET') - pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST') - pytest.raises(r.MethodNotAllowed, adapter.match, '/bar', method='POST') + pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET") + pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST") + pytest.raises(r.MethodNotAllowed, adapter.match, "/bar", method="POST") def test_environ_defaults(): environ = create_environ("/foo") - strict_eq(environ["PATH_INFO"], '/foo') + strict_eq(environ["PATH_INFO"], "/foo") m = r.Map([r.Rule("/foo", endpoint="foo"), r.Rule("/bar", endpoint="bar")]) a = m.bind_to_environ(environ) - strict_eq(a.match("/foo"), ('foo', {})) - strict_eq(a.match(), ('foo', {})) - strict_eq(a.match("/bar"), ('bar', {})) + strict_eq(a.match("/foo"), ("foo", {})) + strict_eq(a.match(), ("foo", {})) + strict_eq(a.match("/bar"), ("bar", {})) pytest.raises(r.NotFound, a.match, "/bars") def test_environ_nonascii_pathinfo(): - environ = create_environ(u'/лошадь') - m = r.Map([ - r.Rule(u'/', endpoint='index'), - r.Rule(u'/лошадь', endpoint='horse') - ]) + environ = create_environ(u"/лошадь") + m = r.Map([r.Rule(u"/", endpoint="index"), r.Rule(u"/лошадь", endpoint="horse")]) a = m.bind_to_environ(environ) - strict_eq(a.match(u'/'), ('index', {})) - strict_eq(a.match(u'/лошадь'), ('horse', {})) - pytest.raises(r.NotFound, a.match, u'/барсук') + strict_eq(a.match(u"/"), ("index", {})) + strict_eq(a.match(u"/лошадь"), ("horse", {})) + pytest.raises(r.NotFound, a.match, u"/барсук") def test_basic_building(): - map = r.Map([ - r.Rule('/', endpoint='index'), - r.Rule('/foo', endpoint='foo'), - r.Rule('/bar/<baz>', endpoint='bar'), - r.Rule('/bar/<int:bazi>', endpoint='bari'), - r.Rule('/bar/<float:bazf>', endpoint='barf'), - r.Rule('/bar/<path:bazp>', endpoint='barp'), - r.Rule('/hehe', endpoint='blah', subdomain='blah') - ]) - adapter = map.bind('example.org', '/', subdomain='blah') - - assert adapter.build('index', {}) == 'http://example.org/' - assert adapter.build('foo', {}) == 'http://example.org/foo' - assert adapter.build('bar', {'baz': 'blub'}) == \ - 'http://example.org/bar/blub' - assert adapter.build('bari', {'bazi': 50}) == 'http://example.org/bar/50' - assert adapter.build('barf', {'bazf': 0.815}) == \ - 'http://example.org/bar/0.815' - assert adapter.build('barp', {'bazp': 'la/di'}) == \ - 'http://example.org/bar/la/di' - assert adapter.build('blah', {}) == '/hehe' - pytest.raises(r.BuildError, lambda: adapter.build('urks')) - - adapter = map.bind('example.org', '/test', subdomain='blah') - assert adapter.build('index', {}) == 'http://example.org/test/' - assert adapter.build('foo', {}) == 'http://example.org/test/foo' - assert adapter.build('bar', {'baz': 'blub'}) == \ - 'http://example.org/test/bar/blub' - assert adapter.build('bari', {'bazi': 50}) == 'http://example.org/test/bar/50' - assert adapter.build('barf', {'bazf': 0.815}) == 'http://example.org/test/bar/0.815' - assert adapter.build('barp', {'bazp': 'la/di'}) == 'http://example.org/test/bar/la/di' - assert adapter.build('blah', {}) == '/test/hehe' - - adapter = map.bind('example.org') - assert adapter.build('foo', {}) == '/foo' - assert adapter.build('foo', {}, force_external=True) == 'http://example.org/foo' - adapter = map.bind('example.org', url_scheme='') - assert adapter.build('foo', {}) == '/foo' - assert adapter.build('foo', {}, force_external=True) == '//example.org/foo' + map = r.Map( + [ + r.Rule("/", endpoint="index"), + r.Rule("/foo", endpoint="foo"), + r.Rule("/bar/<baz>", endpoint="bar"), + r.Rule("/bar/<int:bazi>", endpoint="bari"), + r.Rule("/bar/<float:bazf>", endpoint="barf"), + r.Rule("/bar/<path:bazp>", endpoint="barp"), + r.Rule("/hehe", endpoint="blah", subdomain="blah"), + ] + ) + adapter = map.bind("example.org", "/", subdomain="blah") + + assert adapter.build("index", {}) == "http://example.org/" + assert adapter.build("foo", {}) == "http://example.org/foo" + assert adapter.build("bar", {"baz": "blub"}) == "http://example.org/bar/blub" + assert adapter.build("bari", {"bazi": 50}) == "http://example.org/bar/50" + assert adapter.build("barf", {"bazf": 0.815}) == "http://example.org/bar/0.815" + assert adapter.build("barp", {"bazp": "la/di"}) == "http://example.org/bar/la/di" + assert adapter.build("blah", {}) == "/hehe" + pytest.raises(r.BuildError, lambda: adapter.build("urks")) + + adapter = map.bind("example.org", "/test", subdomain="blah") + assert adapter.build("index", {}) == "http://example.org/test/" + assert adapter.build("foo", {}) == "http://example.org/test/foo" + assert adapter.build("bar", {"baz": "blub"}) == "http://example.org/test/bar/blub" + assert adapter.build("bari", {"bazi": 50}) == "http://example.org/test/bar/50" + assert adapter.build("barf", {"bazf": 0.815}) == "http://example.org/test/bar/0.815" + assert ( + adapter.build("barp", {"bazp": "la/di"}) == "http://example.org/test/bar/la/di" + ) + assert adapter.build("blah", {}) == "/test/hehe" + + adapter = map.bind("example.org") + assert adapter.build("foo", {}) == "/foo" + assert adapter.build("foo", {}, force_external=True) == "http://example.org/foo" + adapter = map.bind("example.org", url_scheme="") + assert adapter.build("foo", {}) == "/foo" + assert adapter.build("foo", {}, force_external=True) == "//example.org/foo" def test_long_build(): - long_args = dict(('v%d' % x, x) for x in range(10000)) - map = r.Map([ - r.Rule(''.join('/<%s>' % k for k in long_args.keys()), endpoint='bleep', build_only=True) - ]) - adapter = map.bind('localhost', '/') - url = adapter.build('bleep', long_args) - url += '/' + long_args = dict(("v%d" % x, x) for x in range(10000)) + map = r.Map( + [ + r.Rule( + "".join("/<%s>" % k for k in long_args.keys()), + endpoint="bleep", + build_only=True, + ) + ] + ) + adapter = map.bind("localhost", "/") + url = adapter.build("bleep", long_args) + url += "/" for v in long_args.values(): - assert '/%d' % v in url + assert "/%d" % v in url def test_defaults(): - map = r.Map([ - r.Rule('/foo/', defaults={'page': 1}, endpoint='foo'), - r.Rule('/foo/<int:page>', endpoint='foo') - ]) - adapter = map.bind('example.org', '/') + map = r.Map( + [ + r.Rule("/foo/", defaults={"page": 1}, endpoint="foo"), + r.Rule("/foo/<int:page>", endpoint="foo"), + ] + ) + adapter = map.bind("example.org", "/") - assert adapter.match('/foo/') == ('foo', {'page': 1}) - pytest.raises(r.RequestRedirect, lambda: adapter.match('/foo/1')) - assert adapter.match('/foo/2') == ('foo', {'page': 2}) - assert adapter.build('foo', {}) == '/foo/' - assert adapter.build('foo', {'page': 1}) == '/foo/' - assert adapter.build('foo', {'page': 2}) == '/foo/2' + assert adapter.match("/foo/") == ("foo", {"page": 1}) + pytest.raises(r.RequestRedirect, lambda: adapter.match("/foo/1")) + assert adapter.match("/foo/2") == ("foo", {"page": 2}) + assert adapter.build("foo", {}) == "/foo/" + assert adapter.build("foo", {"page": 1}) == "/foo/" + assert adapter.build("foo", {"page": 2}) == "/foo/2" def test_negative(): - map = r.Map([ - r.Rule('/foos/<int(signed=True):page>', endpoint='foos'), - r.Rule('/bars/<float(signed=True):page>', endpoint='bars'), - r.Rule('/foo/<int:page>', endpoint='foo'), - r.Rule('/bar/<float:page>', endpoint='bar') - ]) - adapter = map.bind('example.org', '/') - - assert adapter.match('/foos/-2') == ('foos', {'page': -2}) - assert adapter.match('/foos/-50') == ('foos', {'page': -50}) - assert adapter.match('/bars/-2.0') == ('bars', {'page': -2.0}) - assert adapter.match('/bars/-0.185') == ('bars', {'page': -0.185}) + map = r.Map( + [ + r.Rule("/foos/<int(signed=True):page>", endpoint="foos"), + r.Rule("/bars/<float(signed=True):page>", endpoint="bars"), + r.Rule("/foo/<int:page>", endpoint="foo"), + r.Rule("/bar/<float:page>", endpoint="bar"), + ] + ) + adapter = map.bind("example.org", "/") + + assert adapter.match("/foos/-2") == ("foos", {"page": -2}) + assert adapter.match("/foos/-50") == ("foos", {"page": -50}) + assert adapter.match("/bars/-2.0") == ("bars", {"page": -2.0}) + assert adapter.match("/bars/-0.185") == ("bars", {"page": -0.185}) # Make sure signed values are rejected in unsigned mode - pytest.raises(r.NotFound, lambda: adapter.match('/foo/-2')) - pytest.raises(r.NotFound, lambda: adapter.match('/foo/-50')) - pytest.raises(r.NotFound, lambda: adapter.match('/bar/-0.185')) - pytest.raises(r.NotFound, lambda: adapter.match('/bar/-2.0')) + pytest.raises(r.NotFound, lambda: adapter.match("/foo/-2")) + pytest.raises(r.NotFound, lambda: adapter.match("/foo/-50")) + pytest.raises(r.NotFound, lambda: adapter.match("/bar/-0.185")) + pytest.raises(r.NotFound, lambda: adapter.match("/bar/-2.0")) def test_greedy(): - map = r.Map([ - r.Rule('/foo', endpoint='foo'), - r.Rule('/<path:bar>', endpoint='bar'), - r.Rule('/<path:bar>/<path:blub>', endpoint='bar') - ]) - adapter = map.bind('example.org', '/') + map = r.Map( + [ + r.Rule("/foo", endpoint="foo"), + r.Rule("/<path:bar>", endpoint="bar"), + r.Rule("/<path:bar>/<path:blub>", endpoint="bar"), + ] + ) + adapter = map.bind("example.org", "/") - assert adapter.match('/foo') == ('foo', {}) - assert adapter.match('/blub') == ('bar', {'bar': 'blub'}) - assert adapter.match('/he/he') == ('bar', {'bar': 'he', 'blub': 'he'}) + assert adapter.match("/foo") == ("foo", {}) + assert adapter.match("/blub") == ("bar", {"bar": "blub"}) + assert adapter.match("/he/he") == ("bar", {"bar": "he", "blub": "he"}) - assert adapter.build('foo', {}) == '/foo' - assert adapter.build('bar', {'bar': 'blub'}) == '/blub' - assert adapter.build('bar', {'bar': 'blub', 'blub': 'bar'}) == '/blub/bar' + assert adapter.build("foo", {}) == "/foo" + assert adapter.build("bar", {"bar": "blub"}) == "/blub" + assert adapter.build("bar", {"bar": "blub", "blub": "bar"}) == "/blub/bar" def test_path(): - map = r.Map([ - r.Rule('/', defaults={'name': 'FrontPage'}, endpoint='page'), - r.Rule('/Special', endpoint='special'), - r.Rule('/<int:year>', endpoint='year'), - r.Rule('/<path:name>:foo', endpoint='foopage'), - r.Rule('/<path:name>:<path:name2>', endpoint='twopage'), - r.Rule('/<path:name>', endpoint='page'), - r.Rule('/<path:name>/edit', endpoint='editpage'), - r.Rule('/<path:name>/silly/<path:name2>', endpoint='sillypage'), - r.Rule('/<path:name>/silly/<path:name2>/edit', endpoint='editsillypage'), - r.Rule('/Talk:<path:name>', endpoint='talk'), - r.Rule('/User:<username>', endpoint='user'), - r.Rule('/User:<username>/<path:name>', endpoint='userpage'), - r.Rule('/User:<username>/comment/<int:id>-<int:replyId>', endpoint='usercomment'), - r.Rule('/Files/<path:file>', endpoint='files'), - r.Rule('/<admin>/<manage>/<things>', endpoint='admin'), - ]) - adapter = map.bind('example.org', '/') - - assert adapter.match('/') == ('page', {'name': 'FrontPage'}) - pytest.raises(r.RequestRedirect, lambda: adapter.match('/FrontPage')) - assert adapter.match('/Special') == ('special', {}) - assert adapter.match('/2007') == ('year', {'year': 2007}) - assert adapter.match('/Some:foo') == ('foopage', {'name': 'Some'}) - assert adapter.match('/Some:bar') == ('twopage', {'name': 'Some', 'name2': 'bar'}) - assert adapter.match('/Some/Page') == ('page', {'name': 'Some/Page'}) - assert adapter.match('/Some/Page/edit') == ('editpage', {'name': 'Some/Page'}) - assert adapter.match('/Foo/silly/bar') == ('sillypage', {'name': 'Foo', 'name2': 'bar'}) - assert adapter.match( - '/Foo/silly/bar/edit') == ('editsillypage', {'name': 'Foo', 'name2': 'bar'}) - assert adapter.match('/Talk:Foo/Bar') == ('talk', {'name': 'Foo/Bar'}) - assert adapter.match('/User:thomas') == ('user', {'username': 'thomas'}) - assert adapter.match('/User:thomas/projects/werkzeug') == \ - ('userpage', {'username': 'thomas', 'name': 'projects/werkzeug'}) - assert adapter.match('/User:thomas/comment/123-456') == \ - ('usercomment', {'username': 'thomas', 'id': 123, 'replyId': 456}) - assert adapter.match('/Files/downloads/werkzeug/0.2.zip') == \ - ('files', {'file': 'downloads/werkzeug/0.2.zip'}) - assert adapter.match('/Jerry/eats/cheese') == \ - ('admin', {'admin': 'Jerry', 'manage': 'eats', 'things': 'cheese'}) + map = r.Map( + [ + r.Rule("/", defaults={"name": "FrontPage"}, endpoint="page"), + r.Rule("/Special", endpoint="special"), + r.Rule("/<int:year>", endpoint="year"), + r.Rule("/<path:name>:foo", endpoint="foopage"), + r.Rule("/<path:name>:<path:name2>", endpoint="twopage"), + r.Rule("/<path:name>", endpoint="page"), + r.Rule("/<path:name>/edit", endpoint="editpage"), + r.Rule("/<path:name>/silly/<path:name2>", endpoint="sillypage"), + r.Rule("/<path:name>/silly/<path:name2>/edit", endpoint="editsillypage"), + r.Rule("/Talk:<path:name>", endpoint="talk"), + r.Rule("/User:<username>", endpoint="user"), + r.Rule("/User:<username>/<path:name>", endpoint="userpage"), + r.Rule( + "/User:<username>/comment/<int:id>-<int:replyId>", + endpoint="usercomment", + ), + r.Rule("/Files/<path:file>", endpoint="files"), + r.Rule("/<admin>/<manage>/<things>", endpoint="admin"), + ] + ) + adapter = map.bind("example.org", "/") + + assert adapter.match("/") == ("page", {"name": "FrontPage"}) + pytest.raises(r.RequestRedirect, lambda: adapter.match("/FrontPage")) + assert adapter.match("/Special") == ("special", {}) + assert adapter.match("/2007") == ("year", {"year": 2007}) + assert adapter.match("/Some:foo") == ("foopage", {"name": "Some"}) + assert adapter.match("/Some:bar") == ("twopage", {"name": "Some", "name2": "bar"}) + assert adapter.match("/Some/Page") == ("page", {"name": "Some/Page"}) + assert adapter.match("/Some/Page/edit") == ("editpage", {"name": "Some/Page"}) + assert adapter.match("/Foo/silly/bar") == ( + "sillypage", + {"name": "Foo", "name2": "bar"}, + ) + assert adapter.match("/Foo/silly/bar/edit") == ( + "editsillypage", + {"name": "Foo", "name2": "bar"}, + ) + assert adapter.match("/Talk:Foo/Bar") == ("talk", {"name": "Foo/Bar"}) + assert adapter.match("/User:thomas") == ("user", {"username": "thomas"}) + assert adapter.match("/User:thomas/projects/werkzeug") == ( + "userpage", + {"username": "thomas", "name": "projects/werkzeug"}, + ) + assert adapter.match("/User:thomas/comment/123-456") == ( + "usercomment", + {"username": "thomas", "id": 123, "replyId": 456}, + ) + assert adapter.match("/Files/downloads/werkzeug/0.2.zip") == ( + "files", + {"file": "downloads/werkzeug/0.2.zip"}, + ) + assert adapter.match("/Jerry/eats/cheese") == ( + "admin", + {"admin": "Jerry", "manage": "eats", "things": "cheese"}, + ) def test_dispatch(): - env = create_environ('/') - map = r.Map([ - r.Rule('/', endpoint='root'), - r.Rule('/foo/', endpoint='foo') - ]) + env = create_environ("/") + map = r.Map([r.Rule("/", endpoint="root"), r.Rule("/foo/", endpoint="foo")]) adapter = map.bind_to_environ(env) raise_this = None @@ -296,602 +323,625 @@ def test_dispatch(): if raise_this is not None: raise raise_this return Response(repr((endpoint, values))) - dispatch = lambda p, q=False: Response.force_type( - adapter.dispatch(view_func, p, catch_http_exceptions=q), - env - ) - assert dispatch('/').data == b"('root', {})" - assert dispatch('/foo').status_code == 308 + def dispatch(path, quiet=False): + return Response.force_type( + adapter.dispatch(view_func, path, catch_http_exceptions=quiet), env + ) + + assert dispatch("/").data == b"('root', {})" + assert dispatch("/foo").status_code == 308 raise_this = r.NotFound() - pytest.raises(r.NotFound, lambda: dispatch('/bar')) - assert dispatch('/bar', True).status_code == 404 + pytest.raises(r.NotFound, lambda: dispatch("/bar")) + assert dispatch("/bar", True).status_code == 404 def test_http_host_before_server_name(): env = { - 'HTTP_HOST': 'wiki.example.com', - 'SERVER_NAME': 'web0.example.com', - 'SERVER_PORT': '80', - 'SCRIPT_NAME': '', - 'PATH_INFO': '', - 'REQUEST_METHOD': 'GET', - 'wsgi.url_scheme': 'http' + "HTTP_HOST": "wiki.example.com", + "SERVER_NAME": "web0.example.com", + "SERVER_PORT": "80", + "SCRIPT_NAME": "", + "PATH_INFO": "", + "REQUEST_METHOD": "GET", + "wsgi.url_scheme": "http", } - map = r.Map([r.Rule('/', endpoint='index', subdomain='wiki')]) - adapter = map.bind_to_environ(env, server_name='example.com') - assert adapter.match('/') == ('index', {}) - assert adapter.build('index', force_external=True) == 'http://wiki.example.com/' - assert adapter.build('index') == '/' + map = r.Map([r.Rule("/", endpoint="index", subdomain="wiki")]) + adapter = map.bind_to_environ(env, server_name="example.com") + assert adapter.match("/") == ("index", {}) + assert adapter.build("index", force_external=True) == "http://wiki.example.com/" + assert adapter.build("index") == "/" - env['HTTP_HOST'] = 'admin.example.com' - adapter = map.bind_to_environ(env, server_name='example.com') - assert adapter.build('index') == 'http://wiki.example.com/' + env["HTTP_HOST"] = "admin.example.com" + adapter = map.bind_to_environ(env, server_name="example.com") + assert adapter.build("index") == "http://wiki.example.com/" def test_adapter_url_parameter_sorting(): - map = r.Map([r.Rule('/', endpoint='index')], sort_parameters=True, - sort_key=lambda x: x[1]) - adapter = map.bind('localhost', '/') - assert adapter.build('index', {'x': 20, 'y': 10, 'z': 30}, - force_external=True) == 'http://localhost/?y=10&x=20&z=30' + map = r.Map( + [r.Rule("/", endpoint="index")], sort_parameters=True, sort_key=lambda x: x[1] + ) + adapter = map.bind("localhost", "/") + assert ( + adapter.build("index", {"x": 20, "y": 10, "z": 30}, force_external=True) + == "http://localhost/?y=10&x=20&z=30" + ) def test_request_direct_charset_bug(): - map = r.Map([r.Rule(u'/öäü/')]) - adapter = map.bind('localhost', '/') + map = r.Map([r.Rule(u"/öäü/")]) + adapter = map.bind("localhost", "/") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match(u'/öäü') - assert excinfo.value.new_url == 'http://localhost/%C3%B6%C3%A4%C3%BC/' + adapter.match(u"/öäü") + assert excinfo.value.new_url == "http://localhost/%C3%B6%C3%A4%C3%BC/" def test_request_redirect_default(): - map = r.Map([r.Rule(u'/foo', defaults={'bar': 42}), - r.Rule(u'/foo/<int:bar>')]) - adapter = map.bind('localhost', '/') + map = r.Map([r.Rule(u"/foo", defaults={"bar": 42}), r.Rule(u"/foo/<int:bar>")]) + adapter = map.bind("localhost", "/") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match(u'/foo/42') - assert excinfo.value.new_url == 'http://localhost/foo' + adapter.match(u"/foo/42") + assert excinfo.value.new_url == "http://localhost/foo" def test_request_redirect_default_subdomain(): - map = r.Map([r.Rule(u'/foo', defaults={'bar': 42}, subdomain='test'), - r.Rule(u'/foo/<int:bar>', subdomain='other')]) - adapter = map.bind('localhost', '/', subdomain='other') + map = r.Map( + [ + r.Rule(u"/foo", defaults={"bar": 42}, subdomain="test"), + r.Rule(u"/foo/<int:bar>", subdomain="other"), + ] + ) + adapter = map.bind("localhost", "/", subdomain="other") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match(u'/foo/42') - assert excinfo.value.new_url == 'http://test.localhost/foo' + adapter.match(u"/foo/42") + assert excinfo.value.new_url == "http://test.localhost/foo" def test_adapter_match_return_rule(): - rule = r.Rule('/foo/', endpoint='foo') + rule = r.Rule("/foo/", endpoint="foo") map = r.Map([rule]) - adapter = map.bind('localhost', '/') - assert adapter.match('/foo/', return_rule=True) == (rule, {}) + adapter = map.bind("localhost", "/") + assert adapter.match("/foo/", return_rule=True) == (rule, {}) def test_server_name_interpolation(): - server_name = 'example.invalid' - map = r.Map([r.Rule('/', endpoint='index'), - r.Rule('/', endpoint='alt', subdomain='alt')]) + server_name = "example.invalid" + map = r.Map( + [r.Rule("/", endpoint="index"), r.Rule("/", endpoint="alt", subdomain="alt")] + ) - env = create_environ('/', 'http://%s/' % server_name) + env = create_environ("/", "http://%s/" % server_name) adapter = map.bind_to_environ(env, server_name=server_name) - assert adapter.match() == ('index', {}) + assert adapter.match() == ("index", {}) - env = create_environ('/', 'http://alt.%s/' % server_name) + env = create_environ("/", "http://alt.%s/" % server_name) adapter = map.bind_to_environ(env, server_name=server_name) - assert adapter.match() == ('alt', {}) + assert adapter.match() == ("alt", {}) - env = create_environ('/', 'http://%s/' % server_name) - adapter = map.bind_to_environ(env, server_name='foo') - assert adapter.subdomain == '<invalid>' + env = create_environ("/", "http://%s/" % server_name) + adapter = map.bind_to_environ(env, server_name="foo") + assert adapter.subdomain == "<invalid>" def test_rule_emptying(): - rule = r.Rule('/foo', {'meh': 'muh'}, 'x', ['POST'], - False, 'x', True, None) + rule = r.Rule("/foo", {"meh": "muh"}, "x", ["POST"], False, "x", True, None) rule2 = rule.empty() assert rule.__dict__ == rule2.__dict__ - rule.methods.add('GET') + rule.methods.add("GET") assert rule.__dict__ != rule2.__dict__ - rule.methods.discard('GET') - rule.defaults['meh'] = 'aha' + rule.methods.discard("GET") + rule.defaults["meh"] = "aha" assert rule.__dict__ != rule2.__dict__ def test_rule_unhashable(): - rule = r.Rule('/foo', {'meh': 'muh'}, 'x', ['POST'], - False, 'x', True, None) + rule = r.Rule("/foo", {"meh": "muh"}, "x", ["POST"], False, "x", True, None) pytest.raises(TypeError, hash, rule) def test_rule_templates(): - testcase = r.RuleTemplate([ - r.Submount( - '/test/$app', - [r.Rule('/foo/', endpoint='handle_foo'), - r.Rule('/bar/', endpoint='handle_bar'), - r.Rule('/baz/', endpoint='handle_baz')]), - r.EndpointPrefix( - '${app}', - [r.Rule('/${app}-blah', endpoint='bar'), - r.Rule('/${app}-meh', endpoint='baz')]), - r.Subdomain( - '$app', - [r.Rule('/blah', endpoint='x_bar'), - r.Rule('/meh', endpoint='x_baz')]) - ]) + testcase = r.RuleTemplate( + [ + r.Submount( + "/test/$app", + [ + r.Rule("/foo/", endpoint="handle_foo"), + r.Rule("/bar/", endpoint="handle_bar"), + r.Rule("/baz/", endpoint="handle_baz"), + ], + ), + r.EndpointPrefix( + "${app}", + [ + r.Rule("/${app}-blah", endpoint="bar"), + r.Rule("/${app}-meh", endpoint="baz"), + ], + ), + r.Subdomain( + "$app", + [r.Rule("/blah", endpoint="x_bar"), r.Rule("/meh", endpoint="x_baz")], + ), + ] + ) url_map = r.Map( - [testcase(app='test1'), testcase(app='test2'), testcase(app='test3'), testcase(app='test4') - ]) - - out = sorted([(x.rule, x.subdomain, x.endpoint) - for x in url_map.iter_rules()]) - - assert out == ([ - ('/blah', 'test1', 'x_bar'), - ('/blah', 'test2', 'x_bar'), - ('/blah', 'test3', 'x_bar'), - ('/blah', 'test4', 'x_bar'), - ('/meh', 'test1', 'x_baz'), - ('/meh', 'test2', 'x_baz'), - ('/meh', 'test3', 'x_baz'), - ('/meh', 'test4', 'x_baz'), - ('/test/test1/bar/', '', 'handle_bar'), - ('/test/test1/baz/', '', 'handle_baz'), - ('/test/test1/foo/', '', 'handle_foo'), - ('/test/test2/bar/', '', 'handle_bar'), - ('/test/test2/baz/', '', 'handle_baz'), - ('/test/test2/foo/', '', 'handle_foo'), - ('/test/test3/bar/', '', 'handle_bar'), - ('/test/test3/baz/', '', 'handle_baz'), - ('/test/test3/foo/', '', 'handle_foo'), - ('/test/test4/bar/', '', 'handle_bar'), - ('/test/test4/baz/', '', 'handle_baz'), - ('/test/test4/foo/', '', 'handle_foo'), - ('/test1-blah', '', 'test1bar'), - ('/test1-meh', '', 'test1baz'), - ('/test2-blah', '', 'test2bar'), - ('/test2-meh', '', 'test2baz'), - ('/test3-blah', '', 'test3bar'), - ('/test3-meh', '', 'test3baz'), - ('/test4-blah', '', 'test4bar'), - ('/test4-meh', '', 'test4baz') - ]) + [ + testcase(app="test1"), + testcase(app="test2"), + testcase(app="test3"), + testcase(app="test4"), + ] + ) + + out = sorted([(x.rule, x.subdomain, x.endpoint) for x in url_map.iter_rules()]) + + assert out == [ + ("/blah", "test1", "x_bar"), + ("/blah", "test2", "x_bar"), + ("/blah", "test3", "x_bar"), + ("/blah", "test4", "x_bar"), + ("/meh", "test1", "x_baz"), + ("/meh", "test2", "x_baz"), + ("/meh", "test3", "x_baz"), + ("/meh", "test4", "x_baz"), + ("/test/test1/bar/", "", "handle_bar"), + ("/test/test1/baz/", "", "handle_baz"), + ("/test/test1/foo/", "", "handle_foo"), + ("/test/test2/bar/", "", "handle_bar"), + ("/test/test2/baz/", "", "handle_baz"), + ("/test/test2/foo/", "", "handle_foo"), + ("/test/test3/bar/", "", "handle_bar"), + ("/test/test3/baz/", "", "handle_baz"), + ("/test/test3/foo/", "", "handle_foo"), + ("/test/test4/bar/", "", "handle_bar"), + ("/test/test4/baz/", "", "handle_baz"), + ("/test/test4/foo/", "", "handle_foo"), + ("/test1-blah", "", "test1bar"), + ("/test1-meh", "", "test1baz"), + ("/test2-blah", "", "test2bar"), + ("/test2-meh", "", "test2baz"), + ("/test3-blah", "", "test3bar"), + ("/test3-meh", "", "test3baz"), + ("/test4-blah", "", "test4bar"), + ("/test4-meh", "", "test4baz"), + ] def test_non_string_parts(): - m = r.Map([ - r.Rule('/<foo>', endpoint='foo') - ]) - a = m.bind('example.com') - assert a.build('foo', {'foo': 42}) == '/42' + m = r.Map([r.Rule("/<foo>", endpoint="foo")]) + a = m.bind("example.com") + assert a.build("foo", {"foo": 42}) == "/42" def test_complex_routing_rules(): - m = r.Map([ - r.Rule('/', endpoint='index'), - r.Rule('/<int:blub>', endpoint='an_int'), - r.Rule('/<blub>', endpoint='a_string'), - r.Rule('/foo/', endpoint='nested'), - r.Rule('/foobar/', endpoint='nestedbar'), - r.Rule('/foo/<path:testing>/', endpoint='nested_show'), - r.Rule('/foo/<path:testing>/edit', endpoint='nested_edit'), - r.Rule('/users/', endpoint='users', defaults={'page': 1}), - r.Rule('/users/page/<int:page>', endpoint='users'), - r.Rule('/foox', endpoint='foox'), - r.Rule('/<path:bar>/<path:blub>', endpoint='barx_path_path') - ]) - a = m.bind('example.com') - - assert a.match('/') == ('index', {}) - assert a.match('/42') == ('an_int', {'blub': 42}) - assert a.match('/blub') == ('a_string', {'blub': 'blub'}) - assert a.match('/foo/') == ('nested', {}) - assert a.match('/foobar/') == ('nestedbar', {}) - assert a.match('/foo/1/2/3/') == ('nested_show', {'testing': '1/2/3'}) - assert a.match('/foo/1/2/3/edit') == ('nested_edit', {'testing': '1/2/3'}) - assert a.match('/users/') == ('users', {'page': 1}) - assert a.match('/users/page/2') == ('users', {'page': 2}) - assert a.match('/foox') == ('foox', {}) - assert a.match('/1/2/3') == ('barx_path_path', {'bar': '1', 'blub': '2/3'}) - - assert a.build('index') == '/' - assert a.build('an_int', {'blub': 42}) == '/42' - assert a.build('a_string', {'blub': 'test'}) == '/test' - assert a.build('nested') == '/foo/' - assert a.build('nestedbar') == '/foobar/' - assert a.build('nested_show', {'testing': '1/2/3'}) == '/foo/1/2/3/' - assert a.build('nested_edit', {'testing': '1/2/3'}) == '/foo/1/2/3/edit' - assert a.build('users', {'page': 1}) == '/users/' - assert a.build('users', {'page': 2}) == '/users/page/2' - assert a.build('foox') == '/foox' - assert a.build('barx_path_path', {'bar': '1', 'blub': '2/3'}) == '/1/2/3' + m = r.Map( + [ + r.Rule("/", endpoint="index"), + r.Rule("/<int:blub>", endpoint="an_int"), + r.Rule("/<blub>", endpoint="a_string"), + r.Rule("/foo/", endpoint="nested"), + r.Rule("/foobar/", endpoint="nestedbar"), + r.Rule("/foo/<path:testing>/", endpoint="nested_show"), + r.Rule("/foo/<path:testing>/edit", endpoint="nested_edit"), + r.Rule("/users/", endpoint="users", defaults={"page": 1}), + r.Rule("/users/page/<int:page>", endpoint="users"), + r.Rule("/foox", endpoint="foox"), + r.Rule("/<path:bar>/<path:blub>", endpoint="barx_path_path"), + ] + ) + a = m.bind("example.com") + + assert a.match("/") == ("index", {}) + assert a.match("/42") == ("an_int", {"blub": 42}) + assert a.match("/blub") == ("a_string", {"blub": "blub"}) + assert a.match("/foo/") == ("nested", {}) + assert a.match("/foobar/") == ("nestedbar", {}) + assert a.match("/foo/1/2/3/") == ("nested_show", {"testing": "1/2/3"}) + assert a.match("/foo/1/2/3/edit") == ("nested_edit", {"testing": "1/2/3"}) + assert a.match("/users/") == ("users", {"page": 1}) + assert a.match("/users/page/2") == ("users", {"page": 2}) + assert a.match("/foox") == ("foox", {}) + assert a.match("/1/2/3") == ("barx_path_path", {"bar": "1", "blub": "2/3"}) + + assert a.build("index") == "/" + assert a.build("an_int", {"blub": 42}) == "/42" + assert a.build("a_string", {"blub": "test"}) == "/test" + assert a.build("nested") == "/foo/" + assert a.build("nestedbar") == "/foobar/" + assert a.build("nested_show", {"testing": "1/2/3"}) == "/foo/1/2/3/" + assert a.build("nested_edit", {"testing": "1/2/3"}) == "/foo/1/2/3/edit" + assert a.build("users", {"page": 1}) == "/users/" + assert a.build("users", {"page": 2}) == "/users/page/2" + assert a.build("foox") == "/foox" + assert a.build("barx_path_path", {"bar": "1", "blub": "2/3"}) == "/1/2/3" def test_default_converters(): class MyMap(r.Map): default_converters = r.Map.default_converters.copy() - default_converters['foo'] = r.UnicodeConverter + default_converters["foo"] = r.UnicodeConverter + assert isinstance(r.Map.default_converters, ImmutableDict) - m = MyMap([ - r.Rule('/a/<foo:a>', endpoint='a'), - r.Rule('/b/<foo:b>', endpoint='b'), - r.Rule('/c/<c>', endpoint='c') - ], converters={'bar': r.UnicodeConverter}) - a = m.bind('example.org', '/') - assert a.match('/a/1') == ('a', {'a': '1'}) - assert a.match('/b/2') == ('b', {'b': '2'}) - assert a.match('/c/3') == ('c', {'c': '3'}) - assert 'foo' not in r.Map.default_converters + m = MyMap( + [ + r.Rule("/a/<foo:a>", endpoint="a"), + r.Rule("/b/<foo:b>", endpoint="b"), + r.Rule("/c/<c>", endpoint="c"), + ], + converters={"bar": r.UnicodeConverter}, + ) + a = m.bind("example.org", "/") + assert a.match("/a/1") == ("a", {"a": "1"}) + assert a.match("/b/2") == ("b", {"b": "2"}) + assert a.match("/c/3") == ("c", {"c": "3"}) + assert "foo" not in r.Map.default_converters def test_uuid_converter(): - m = r.Map([ - r.Rule('/a/<uuid:a_uuid>', endpoint='a') - ]) - a = m.bind('example.org', '/') - rooute, kwargs = a.match('/a/a8098c1a-f86e-11da-bd1a-00112444be1e') - assert type(kwargs['a_uuid']) == uuid.UUID + m = r.Map([r.Rule("/a/<uuid:a_uuid>", endpoint="a")]) + a = m.bind("example.org", "/") + rooute, kwargs = a.match("/a/a8098c1a-f86e-11da-bd1a-00112444be1e") + assert type(kwargs["a_uuid"]) == uuid.UUID def test_converter_with_tuples(): - ''' + """ Regression test for https://github.com/pallets/werkzeug/issues/709 - ''' - class TwoValueConverter(r.BaseConverter): + """ + class TwoValueConverter(r.BaseConverter): def __init__(self, *args, **kwargs): super(TwoValueConverter, self).__init__(*args, **kwargs) - self.regex = r'(\w\w+)/(\w\w+)' + self.regex = r"(\w\w+)/(\w\w+)" def to_python(self, two_values): - one, two = two_values.split('/') + one, two = two_values.split("/") return one, two def to_url(self, values): return "%s/%s" % (values[0], values[1]) - map = r.Map([ - r.Rule('/<two:foo>/', endpoint='handler') - ], converters={'two': TwoValueConverter}) - a = map.bind('example.org', '/') - route, kwargs = a.match('/qwert/yuiop/') - assert kwargs['foo'] == ('qwert', 'yuiop') + map = r.Map( + [r.Rule("/<two:foo>/", endpoint="handler")], + converters={"two": TwoValueConverter}, + ) + a = map.bind("example.org", "/") + route, kwargs = a.match("/qwert/yuiop/") + assert kwargs["foo"] == ("qwert", "yuiop") def test_anyconverter(): - m = r.Map([ - r.Rule('/<any(a1, a2):a>', endpoint='no_dot'), - r.Rule('/<any(a.1, a.2):a>', endpoint='yes_dot') - ]) - a = m.bind('example.org', '/') - assert a.match('/a1') == ('no_dot', {'a': 'a1'}) - assert a.match('/a2') == ('no_dot', {'a': 'a2'}) - assert a.match('/a.1') == ('yes_dot', {'a': 'a.1'}) - assert a.match('/a.2') == ('yes_dot', {'a': 'a.2'}) + m = r.Map( + [ + r.Rule("/<any(a1, a2):a>", endpoint="no_dot"), + r.Rule("/<any(a.1, a.2):a>", endpoint="yes_dot"), + ] + ) + a = m.bind("example.org", "/") + assert a.match("/a1") == ("no_dot", {"a": "a1"}) + assert a.match("/a2") == ("no_dot", {"a": "a2"}) + assert a.match("/a.1") == ("yes_dot", {"a": "a.1"}) + assert a.match("/a.2") == ("yes_dot", {"a": "a.2"}) def test_build_append_unknown(): - map = r.Map([ - r.Rule('/bar/<float:bazf>', endpoint='barf') - ]) - adapter = map.bind('example.org', '/', subdomain='blah') - assert adapter.build('barf', {'bazf': 0.815, 'bif': 1.0}) == \ - 'http://example.org/bar/0.815?bif=1.0' - assert adapter.build('barf', {'bazf': 0.815, 'bif': 1.0}, - append_unknown=False) == 'http://example.org/bar/0.815' + map = r.Map([r.Rule("/bar/<float:bazf>", endpoint="barf")]) + adapter = map.bind("example.org", "/", subdomain="blah") + assert ( + adapter.build("barf", {"bazf": 0.815, "bif": 1.0}) + == "http://example.org/bar/0.815?bif=1.0" + ) + assert ( + adapter.build("barf", {"bazf": 0.815, "bif": 1.0}, append_unknown=False) + == "http://example.org/bar/0.815" + ) def test_build_append_multiple(): - map = r.Map([ - r.Rule('/bar/<float:foo>', endpoint='endp') - ]) - adapter = map.bind('example.org', '/', subdomain='subd') - params = {'foo': 0.815, 'x': [1.0, 3.0], 'y': 2.0} - a, b = adapter.build('endp', params).split('?') - assert a == 'http://example.org/bar/0.815' - assert set(b.split('&')) == set('y=2.0&x=1.0&x=3.0'.split('&')) + map = r.Map([r.Rule("/bar/<float:foo>", endpoint="endp")]) + adapter = map.bind("example.org", "/", subdomain="subd") + params = {"foo": 0.815, "x": [1.0, 3.0], "y": 2.0} + a, b = adapter.build("endp", params).split("?") + assert a == "http://example.org/bar/0.815" + assert set(b.split("&")) == set("y=2.0&x=1.0&x=3.0".split("&")) def test_build_append_multidict(): - map = r.Map([ - r.Rule('/bar/<float:foo>', endpoint='endp') - ]) - adapter = map.bind('example.org', '/', subdomain='subd') - params = MultiDict( - (('foo', 0.815), ('x', 1.0), ('x', 3.0), ('y', 2.0))) - a, b = adapter.build('endp', params).split('?') - assert a == 'http://example.org/bar/0.815' - assert set(b.split('&')) == set('y=2.0&x=1.0&x=3.0'.split('&')) + map = r.Map([r.Rule("/bar/<float:foo>", endpoint="endp")]) + adapter = map.bind("example.org", "/", subdomain="subd") + params = MultiDict((("foo", 0.815), ("x", 1.0), ("x", 3.0), ("y", 2.0))) + a, b = adapter.build("endp", params).split("?") + assert a == "http://example.org/bar/0.815" + assert set(b.split("&")) == set("y=2.0&x=1.0&x=3.0".split("&")) def test_build_drop_none(): - map = r.Map([ - r.Rule('/flob/<flub>', endpoint='endp') - ]) - adapter = map.bind('', '/') - params = {'flub': None, 'flop': None} + map = r.Map([r.Rule("/flob/<flub>", endpoint="endp")]) + adapter = map.bind("", "/") + params = {"flub": None, "flop": None} with pytest.raises(r.BuildError): - x = adapter.build('endp', params) + x = adapter.build("endp", params) assert not x - params = {'flub': 'x', 'flop': None} - url = adapter.build('endp', params) - assert 'flop' not in url + params = {"flub": "x", "flop": None} + url = adapter.build("endp", params) + assert "flop" not in url def test_method_fallback(): - map = r.Map([ - r.Rule('/', endpoint='index', methods=['GET']), - r.Rule('/<name>', endpoint='hello_name', methods=['GET']), - r.Rule('/select', endpoint='hello_select', methods=['POST']), - r.Rule('/search_get', endpoint='search', methods=['GET']), - r.Rule('/search_post', endpoint='search', methods=['POST']) - ]) - adapter = map.bind('example.com') - assert adapter.build('index') == '/' - assert adapter.build('index', method='GET') == '/' - assert adapter.build('hello_name', {'name': 'foo'}) == '/foo' - assert adapter.build('hello_select') == '/select' - assert adapter.build('hello_select', method='POST') == '/select' - assert adapter.build('search') == '/search_get' - assert adapter.build('search', method='GET') == '/search_get' - assert adapter.build('search', method='POST') == '/search_post' + map = r.Map( + [ + r.Rule("/", endpoint="index", methods=["GET"]), + r.Rule("/<name>", endpoint="hello_name", methods=["GET"]), + r.Rule("/select", endpoint="hello_select", methods=["POST"]), + r.Rule("/search_get", endpoint="search", methods=["GET"]), + r.Rule("/search_post", endpoint="search", methods=["POST"]), + ] + ) + adapter = map.bind("example.com") + assert adapter.build("index") == "/" + assert adapter.build("index", method="GET") == "/" + assert adapter.build("hello_name", {"name": "foo"}) == "/foo" + assert adapter.build("hello_select") == "/select" + assert adapter.build("hello_select", method="POST") == "/select" + assert adapter.build("search") == "/search_get" + assert adapter.build("search", method="GET") == "/search_get" + assert adapter.build("search", method="POST") == "/search_post" def test_implicit_head(): - url_map = r.Map([ - r.Rule('/get', methods=['GET'], endpoint='a'), - r.Rule('/post', methods=['POST'], endpoint='b') - ]) - adapter = url_map.bind('example.org') - assert adapter.match('/get', method='HEAD') == ('a', {}) - pytest.raises(r.MethodNotAllowed, adapter.match, - '/post', method='HEAD') + url_map = r.Map( + [ + r.Rule("/get", methods=["GET"], endpoint="a"), + r.Rule("/post", methods=["POST"], endpoint="b"), + ] + ) + adapter = url_map.bind("example.org") + assert adapter.match("/get", method="HEAD") == ("a", {}) + pytest.raises(r.MethodNotAllowed, adapter.match, "/post", method="HEAD") def test_pass_str_as_router_methods(): with pytest.raises(TypeError): - r.Rule('/get', methods='GET') + r.Rule("/get", methods="GET") def test_protocol_joining_bug(): - m = r.Map([r.Rule('/<foo>', endpoint='x')]) - a = m.bind('example.org') - assert a.build('x', {'foo': 'x:y'}) == '/x:y' - assert a.build('x', {'foo': 'x:y'}, force_external=True) == \ - 'http://example.org/x:y' + m = r.Map([r.Rule("/<foo>", endpoint="x")]) + a = m.bind("example.org") + assert a.build("x", {"foo": "x:y"}) == "/x:y" + assert a.build("x", {"foo": "x:y"}, force_external=True) == "http://example.org/x:y" def test_allowed_methods_querying(): - m = r.Map([r.Rule('/<foo>', methods=['GET', 'HEAD']), - r.Rule('/foo', methods=['POST'])]) - a = m.bind('example.org') - assert sorted(a.allowed_methods('/foo')) == ['GET', 'HEAD', 'POST'] + m = r.Map( + [r.Rule("/<foo>", methods=["GET", "HEAD"]), r.Rule("/foo", methods=["POST"])] + ) + a = m.bind("example.org") + assert sorted(a.allowed_methods("/foo")) == ["GET", "HEAD", "POST"] def test_external_building_with_port(): - map = r.Map([ - r.Rule('/', endpoint='index'), - ]) - adapter = map.bind('example.org:5000', '/') - built_url = adapter.build('index', {}, force_external=True) - assert built_url == 'http://example.org:5000/', built_url + map = r.Map([r.Rule("/", endpoint="index")]) + adapter = map.bind("example.org:5000", "/") + built_url = adapter.build("index", {}, force_external=True) + assert built_url == "http://example.org:5000/", built_url def test_external_building_with_port_bind_to_environ(): - map = r.Map([ - r.Rule('/', endpoint='index'), - ]) + map = r.Map([r.Rule("/", endpoint="index")]) adapter = map.bind_to_environ( - create_environ('/', 'http://example.org:5000/'), - server_name="example.org:5000" + create_environ("/", "http://example.org:5000/"), server_name="example.org:5000" ) - built_url = adapter.build('index', {}, force_external=True) - assert built_url == 'http://example.org:5000/', built_url + built_url = adapter.build("index", {}, force_external=True) + assert built_url == "http://example.org:5000/", built_url def test_external_building_with_port_bind_to_environ_wrong_servername(): - map = r.Map([ - r.Rule('/', endpoint='index'), - ]) - environ = create_environ('/', 'http://example.org:5000/') + map = r.Map([r.Rule("/", endpoint="index")]) + environ = create_environ("/", "http://example.org:5000/") adapter = map.bind_to_environ(environ, server_name="example.org") - assert adapter.subdomain == '<invalid>' + assert adapter.subdomain == "<invalid>" def test_converter_parser(): - args, kwargs = r.parse_converter_args(u'test, a=1, b=3.0') + args, kwargs = r.parse_converter_args(u"test, a=1, b=3.0") - assert args == ('test',) - assert kwargs == {'a': 1, 'b': 3.0} + assert args == ("test",) + assert kwargs == {"a": 1, "b": 3.0} - args, kwargs = r.parse_converter_args('') + args, kwargs = r.parse_converter_args("") assert not args and not kwargs - args, kwargs = r.parse_converter_args('a, b, c,') - assert args == ('a', 'b', 'c') + args, kwargs = r.parse_converter_args("a, b, c,") + assert args == ("a", "b", "c") assert not kwargs - args, kwargs = r.parse_converter_args('True, False, None') + args, kwargs = r.parse_converter_args("True, False, None") assert args == (True, False, None) args, kwargs = r.parse_converter_args('"foo", u"bar"') - assert args == ('foo', 'bar') + assert args == ("foo", "bar") def test_alias_redirects(): - m = r.Map([ - r.Rule('/', endpoint='index'), - r.Rule('/index.html', endpoint='index', alias=True), - r.Rule('/users/', defaults={'page': 1}, endpoint='users'), - r.Rule('/users/index.html', defaults={'page': 1}, alias=True, - endpoint='users'), - r.Rule('/users/page/<int:page>', endpoint='users'), - r.Rule('/users/page-<int:page>.html', alias=True, endpoint='users'), - ]) - a = m.bind('example.com') + m = r.Map( + [ + r.Rule("/", endpoint="index"), + r.Rule("/index.html", endpoint="index", alias=True), + r.Rule("/users/", defaults={"page": 1}, endpoint="users"), + r.Rule( + "/users/index.html", defaults={"page": 1}, alias=True, endpoint="users" + ), + r.Rule("/users/page/<int:page>", endpoint="users"), + r.Rule("/users/page-<int:page>.html", alias=True, endpoint="users"), + ] + ) + a = m.bind("example.com") def ensure_redirect(path, new_url, args=None): with pytest.raises(r.RequestRedirect) as excinfo: a.match(path, query_args=args) - assert excinfo.value.new_url == 'http://example.com' + new_url + assert excinfo.value.new_url == "http://example.com" + new_url - ensure_redirect('/index.html', '/') - ensure_redirect('/users/index.html', '/users/') - ensure_redirect('/users/page-2.html', '/users/page/2') - ensure_redirect('/users/page-1.html', '/users/') - ensure_redirect('/users/page-1.html', '/users/?foo=bar', {'foo': 'bar'}) + ensure_redirect("/index.html", "/") + ensure_redirect("/users/index.html", "/users/") + ensure_redirect("/users/page-2.html", "/users/page/2") + ensure_redirect("/users/page-1.html", "/users/") + ensure_redirect("/users/page-1.html", "/users/?foo=bar", {"foo": "bar"}) - assert a.build('index') == '/' - assert a.build('users', {'page': 1}) == '/users/' - assert a.build('users', {'page': 2}) == '/users/page/2' + assert a.build("index") == "/" + assert a.build("users", {"page": 1}) == "/users/" + assert a.build("users", {"page": 2}) == "/users/page/2" -@pytest.mark.parametrize('prefix', ('', '/aaa')) +@pytest.mark.parametrize("prefix", ("", "/aaa")) def test_double_defaults(prefix): - m = r.Map([ - r.Rule(prefix + '/', defaults={'foo': 1, 'bar': False}, endpoint='x'), - r.Rule(prefix + '/<int:foo>', defaults={'bar': False}, endpoint='x'), - r.Rule(prefix + '/bar/', defaults={'foo': 1, 'bar': True}, endpoint='x'), - r.Rule(prefix + '/bar/<int:foo>', defaults={'bar': True}, endpoint='x') - ]) - a = m.bind('example.com') - - assert a.match(prefix + '/') == ('x', {'foo': 1, 'bar': False}) - assert a.match(prefix + '/2') == ('x', {'foo': 2, 'bar': False}) - assert a.match(prefix + '/bar/') == ('x', {'foo': 1, 'bar': True}) - assert a.match(prefix + '/bar/2') == ('x', {'foo': 2, 'bar': True}) - - assert a.build('x', {'foo': 1, 'bar': False}) == prefix + '/' - assert a.build('x', {'foo': 2, 'bar': False}) == prefix + '/2' - assert a.build('x', {'bar': False}) == prefix + '/' - assert a.build('x', {'foo': 1, 'bar': True}) == prefix + '/bar/' - assert a.build('x', {'foo': 2, 'bar': True}) == prefix + '/bar/2' - assert a.build('x', {'bar': True}) == prefix + '/bar/' + m = r.Map( + [ + r.Rule(prefix + "/", defaults={"foo": 1, "bar": False}, endpoint="x"), + r.Rule(prefix + "/<int:foo>", defaults={"bar": False}, endpoint="x"), + r.Rule(prefix + "/bar/", defaults={"foo": 1, "bar": True}, endpoint="x"), + r.Rule(prefix + "/bar/<int:foo>", defaults={"bar": True}, endpoint="x"), + ] + ) + a = m.bind("example.com") + + assert a.match(prefix + "/") == ("x", {"foo": 1, "bar": False}) + assert a.match(prefix + "/2") == ("x", {"foo": 2, "bar": False}) + assert a.match(prefix + "/bar/") == ("x", {"foo": 1, "bar": True}) + assert a.match(prefix + "/bar/2") == ("x", {"foo": 2, "bar": True}) + + assert a.build("x", {"foo": 1, "bar": False}) == prefix + "/" + assert a.build("x", {"foo": 2, "bar": False}) == prefix + "/2" + assert a.build("x", {"bar": False}) == prefix + "/" + assert a.build("x", {"foo": 1, "bar": True}) == prefix + "/bar/" + assert a.build("x", {"foo": 2, "bar": True}) == prefix + "/bar/2" + assert a.build("x", {"bar": True}) == prefix + "/bar/" def test_host_matching(): - m = r.Map([ - r.Rule('/', endpoint='index', host='www.<domain>'), - r.Rule('/', endpoint='files', host='files.<domain>'), - r.Rule('/foo/', defaults={'page': 1}, host='www.<domain>', endpoint='x'), - r.Rule('/<int:page>', host='files.<domain>', endpoint='x') - ], host_matching=True) + m = r.Map( + [ + r.Rule("/", endpoint="index", host="www.<domain>"), + r.Rule("/", endpoint="files", host="files.<domain>"), + r.Rule("/foo/", defaults={"page": 1}, host="www.<domain>", endpoint="x"), + r.Rule("/<int:page>", host="files.<domain>", endpoint="x"), + ], + host_matching=True, + ) - a = m.bind('www.example.com') - assert a.match('/') == ('index', {'domain': 'example.com'}) - assert a.match('/foo/') == ('x', {'domain': 'example.com', 'page': 1}) + a = m.bind("www.example.com") + assert a.match("/") == ("index", {"domain": "example.com"}) + assert a.match("/foo/") == ("x", {"domain": "example.com", "page": 1}) with pytest.raises(r.RequestRedirect) as excinfo: - a.match('/foo') - assert excinfo.value.new_url == 'http://www.example.com/foo/' + a.match("/foo") + assert excinfo.value.new_url == "http://www.example.com/foo/" - a = m.bind('files.example.com') - assert a.match('/') == ('files', {'domain': 'example.com'}) - assert a.match('/2') == ('x', {'domain': 'example.com', 'page': 2}) + a = m.bind("files.example.com") + assert a.match("/") == ("files", {"domain": "example.com"}) + assert a.match("/2") == ("x", {"domain": "example.com", "page": 2}) with pytest.raises(r.RequestRedirect) as excinfo: - a.match('/1') - assert excinfo.value.new_url == 'http://www.example.com/foo/' + a.match("/1") + assert excinfo.value.new_url == "http://www.example.com/foo/" def test_host_matching_building(): - m = r.Map([ - r.Rule('/', endpoint='index', host='www.domain.com'), - r.Rule('/', endpoint='foo', host='my.domain.com') - ], host_matching=True) + m = r.Map( + [ + r.Rule("/", endpoint="index", host="www.domain.com"), + r.Rule("/", endpoint="foo", host="my.domain.com"), + ], + host_matching=True, + ) - www = m.bind('www.domain.com') - assert www.match('/') == ('index', {}) - assert www.build('index') == '/' - assert www.build('foo') == 'http://my.domain.com/' + www = m.bind("www.domain.com") + assert www.match("/") == ("index", {}) + assert www.build("index") == "/" + assert www.build("foo") == "http://my.domain.com/" - my = m.bind('my.domain.com') - assert my.match('/') == ('foo', {}) - assert my.build('foo') == '/' - assert my.build('index') == 'http://www.domain.com/' + my = m.bind("my.domain.com") + assert my.match("/") == ("foo", {}) + assert my.build("foo") == "/" + assert my.build("index") == "http://www.domain.com/" def test_server_name_casing(): - m = r.Map([ - r.Rule('/', endpoint='index', subdomain='foo') - ]) + m = r.Map([r.Rule("/", endpoint="index", subdomain="foo")]) env = create_environ() - env['SERVER_NAME'] = env['HTTP_HOST'] = 'FOO.EXAMPLE.COM' - a = m.bind_to_environ(env, server_name='example.com') - assert a.match('/') == ('index', {}) + env["SERVER_NAME"] = env["HTTP_HOST"] = "FOO.EXAMPLE.COM" + a = m.bind_to_environ(env, server_name="example.com") + assert a.match("/") == ("index", {}) env = create_environ() - env['SERVER_NAME'] = '127.0.0.1' - env['SERVER_PORT'] = '5000' - del env['HTTP_HOST'] - a = m.bind_to_environ(env, server_name='example.com') + env["SERVER_NAME"] = "127.0.0.1" + env["SERVER_PORT"] = "5000" + del env["HTTP_HOST"] + a = m.bind_to_environ(env, server_name="example.com") with pytest.raises(r.NotFound): a.match() def test_redirect_request_exception_code(): - exc = r.RequestRedirect('http://www.google.com/') + exc = r.RequestRedirect("http://www.google.com/") exc.code = 307 env = create_environ() strict_eq(exc.get_response(env).status_code, exc.code) def test_redirect_path_quoting(): - url_map = r.Map([ - r.Rule('/<category>', defaults={'page': 1}, endpoint='category'), - r.Rule('/<category>/page/<int:page>', endpoint='category') - ]) - adapter = url_map.bind('example.com') + url_map = r.Map( + [ + r.Rule("/<category>", defaults={"page": 1}, endpoint="category"), + r.Rule("/<category>/page/<int:page>", endpoint="category"), + ] + ) + adapter = url_map.bind("example.com") with pytest.raises(r.RequestRedirect) as excinfo: - adapter.match('/foo bar/page/1') + adapter.match("/foo bar/page/1") response = excinfo.value.get_response({}) - strict_eq(response.headers['location'], - u'http://example.com/foo%20bar') + strict_eq(response.headers["location"], u"http://example.com/foo%20bar") def test_unicode_rules(): - m = r.Map([ - r.Rule(u'/войти/', endpoint='enter'), - r.Rule(u'/foo+bar/', endpoint='foobar') - ]) - a = m.bind(u'☃.example.com') + m = r.Map( + [r.Rule(u"/войти/", endpoint="enter"), r.Rule(u"/foo+bar/", endpoint="foobar")] + ) + a = m.bind(u"☃.example.com") with pytest.raises(r.RequestRedirect) as excinfo: - a.match(u'/войти') - strict_eq(excinfo.value.new_url, - 'http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/') + a.match(u"/войти") + strict_eq( + excinfo.value.new_url, + "http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/", + ) - endpoint, values = a.match(u'/войти/') - strict_eq(endpoint, 'enter') + endpoint, values = a.match(u"/войти/") + strict_eq(endpoint, "enter") strict_eq(values, {}) with pytest.raises(r.RequestRedirect) as excinfo: - a.match(u'/foo+bar') - strict_eq(excinfo.value.new_url, 'http://xn--n3h.example.com/foo+bar/') + a.match(u"/foo+bar") + strict_eq(excinfo.value.new_url, "http://xn--n3h.example.com/foo+bar/") - endpoint, values = a.match(u'/foo+bar/') - strict_eq(endpoint, 'foobar') + endpoint, values = a.match(u"/foo+bar/") + strict_eq(endpoint, "foobar") strict_eq(values, {}) - url = a.build('enter', {}, force_external=True) - strict_eq(url, 'http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/') + url = a.build("enter", {}, force_external=True) + strict_eq(url, "http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/") - url = a.build('foobar', {}, force_external=True) - strict_eq(url, 'http://xn--n3h.example.com/foo+bar/') + url = a.build("foobar", {}, force_external=True) + strict_eq(url, "http://xn--n3h.example.com/foo+bar/") def test_empty_path_info(): - m = r.Map([ - r.Rule("/", endpoint="index"), - ]) + m = r.Map([r.Rule("/", endpoint="index")]) b = m.bind("example.com", script_name="/approot") with pytest.raises(r.RequestRedirect) as excinfo: @@ -905,29 +955,24 @@ def test_empty_path_info(): def test_both_bind_and_match_path_info_are_none(): - m = r.Map([r.Rule(u'/', endpoint='index')]) - ma = m.bind('example.org') - strict_eq(ma.match(), ('index', {})) + m = r.Map([r.Rule(u"/", endpoint="index")]) + ma = m.bind("example.org") + strict_eq(ma.match(), ("index", {})) def test_map_repr(): - m = r.Map([ - r.Rule(u'/wat', endpoint='enter'), - r.Rule(u'/woop', endpoint='foobar') - ]) + m = r.Map([r.Rule(u"/wat", endpoint="enter"), r.Rule(u"/woop", endpoint="foobar")]) rv = repr(m) - strict_eq(rv, - "Map([<Rule '/woop' -> foobar>, <Rule '/wat' -> enter>])") + strict_eq(rv, "Map([<Rule '/woop' -> foobar>, <Rule '/wat' -> enter>])") def test_empty_subclass_rules_with_custom_kwargs(): class CustomRule(r.Rule): - def __init__(self, string=None, custom=None, *args, **kwargs): self.custom = custom super(CustomRule, self).__init__(string, *args, **kwargs) - rule1 = CustomRule(u'/foo', endpoint='bar') + rule1 = CustomRule(u"/foo", endpoint="bar") try: rule2 = rule1.empty() assert rule1.rule == rule2.rule @@ -936,82 +981,79 @@ def test_empty_subclass_rules_with_custom_kwargs(): def test_finding_closest_match_by_endpoint(): - m = r.Map([ - r.Rule(u'/foo/', endpoint='users.here'), - r.Rule(u'/wat/', endpoint='admin.users'), - r.Rule(u'/woop', endpoint='foo.users'), - ]) - adapter = m.bind('example.com') - assert r.BuildError('admin.user', None, None, adapter).suggested.endpoint \ - == 'admin.users' + m = r.Map( + [ + r.Rule(u"/foo/", endpoint="users.here"), + r.Rule(u"/wat/", endpoint="admin.users"), + r.Rule(u"/woop", endpoint="foo.users"), + ] + ) + adapter = m.bind("example.com") + assert ( + r.BuildError("admin.user", None, None, adapter).suggested.endpoint + == "admin.users" + ) def test_finding_closest_match_by_values(): - rule_id = r.Rule(u'/user/id/<id>/', endpoint='users') - rule_slug = r.Rule(u'/user/<slug>/', endpoint='users') - rule_random = r.Rule(u'/user/emails/<email>/', endpoint='users') + rule_id = r.Rule(u"/user/id/<id>/", endpoint="users") + rule_slug = r.Rule(u"/user/<slug>/", endpoint="users") + rule_random = r.Rule(u"/user/emails/<email>/", endpoint="users") m = r.Map([rule_id, rule_slug, rule_random]) - adapter = m.bind('example.com') - assert r.BuildError('x', {'slug': ''}, None, adapter).suggested == \ - rule_slug + adapter = m.bind("example.com") + assert r.BuildError("x", {"slug": ""}, None, adapter).suggested == rule_slug def test_finding_closest_match_by_method(): - post = r.Rule(u'/post/', endpoint='foobar', methods=['POST']) - get = r.Rule(u'/get/', endpoint='foobar', methods=['GET']) - put = r.Rule(u'/put/', endpoint='foobar', methods=['PUT']) + post = r.Rule(u"/post/", endpoint="foobar", methods=["POST"]) + get = r.Rule(u"/get/", endpoint="foobar", methods=["GET"]) + put = r.Rule(u"/put/", endpoint="foobar", methods=["PUT"]) m = r.Map([post, get, put]) - adapter = m.bind('example.com') - assert r.BuildError('invalid', {}, 'POST', adapter).suggested == post - assert r.BuildError('invalid', {}, 'GET', adapter).suggested == get - assert r.BuildError('invalid', {}, 'PUT', adapter).suggested == put + adapter = m.bind("example.com") + assert r.BuildError("invalid", {}, "POST", adapter).suggested == post + assert r.BuildError("invalid", {}, "GET", adapter).suggested == get + assert r.BuildError("invalid", {}, "PUT", adapter).suggested == put def test_finding_closest_match_when_none_exist(): m = r.Map([]) - assert not r.BuildError('invalid', {}, None, m.bind('test.com')).suggested + assert not r.BuildError("invalid", {}, None, m.bind("test.com")).suggested def test_error_message_without_suggested_rule(): - m = r.Map([ - r.Rule(u'/foo/', endpoint='world', methods=['GET']), - ]) - adapter = m.bind('example.com') + m = r.Map([r.Rule(u"/foo/", endpoint="world", methods=["GET"])]) + adapter = m.bind("example.com") with pytest.raises(r.BuildError) as excinfo: - adapter.build('urks') - assert str(excinfo.value).startswith( - "Could not build url for endpoint 'urks'." - ) + adapter.build("urks") + assert str(excinfo.value).startswith("Could not build url for endpoint 'urks'.") with pytest.raises(r.BuildError) as excinfo: - adapter.build('world', method='POST') + adapter.build("world", method="POST") assert str(excinfo.value).startswith( "Could not build url for endpoint 'world' ('POST')." ) with pytest.raises(r.BuildError) as excinfo: - adapter.build('urks', values={'user_id': 5}) + adapter.build("urks", values={"user_id": 5}) assert str(excinfo.value).startswith( "Could not build url for endpoint 'urks' with values ['user_id']." ) def test_error_message_suggestion(): - m = r.Map([ - r.Rule(u'/foo/<id>/', endpoint='world', methods=['GET']), - ]) - adapter = m.bind('example.com') + m = r.Map([r.Rule(u"/foo/<id>/", endpoint="world", methods=["GET"])]) + adapter = m.bind("example.com") with pytest.raises(r.BuildError) as excinfo: - adapter.build('helloworld') + adapter.build("helloworld") assert "Did you mean 'world' instead?" in str(excinfo.value) with pytest.raises(r.BuildError) as excinfo: - adapter.build('world') + adapter.build("world") assert "Did you forget to specify values ['id']?" in str(excinfo.value) assert "Did you mean to use methods" not in str(excinfo.value) with pytest.raises(r.BuildError) as excinfo: - adapter.build('world', {'id': 2}, method='POST') + adapter.build("world", {"id": 2}, method="POST") assert "Did you mean to use methods ['GET', 'HEAD']?" in str(excinfo.value) diff --git a/tests/test_security.py b/tests/test_security.py index e30a0542..540bba92 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -10,79 +10,85 @@ """ import os import posixpath + import pytest -from werkzeug.security import check_password_hash, generate_password_hash, \ - safe_join, pbkdf2_hex, safe_str_cmp +from werkzeug.security import check_password_hash +from werkzeug.security import generate_password_hash +from werkzeug.security import pbkdf2_hex +from werkzeug.security import safe_join +from werkzeug.security import safe_str_cmp def test_safe_str_cmp(): - assert safe_str_cmp('a', 'a') is True - assert safe_str_cmp(b'a', u'a') is True - assert safe_str_cmp('a', 'b') is False - assert safe_str_cmp(b'aaa', 'aa') is False - assert safe_str_cmp(b'aaa', 'bbb') is False - assert safe_str_cmp(b'aaa', u'aaa') is True - assert safe_str_cmp(u'aaa', u'aaa') is True + assert safe_str_cmp("a", "a") is True + assert safe_str_cmp(b"a", u"a") is True + assert safe_str_cmp("a", "b") is False + assert safe_str_cmp(b"aaa", "aa") is False + assert safe_str_cmp(b"aaa", "bbb") is False + assert safe_str_cmp(b"aaa", u"aaa") is True + assert safe_str_cmp(u"aaa", u"aaa") is True def test_safe_str_cmp_no_builtin(): import werkzeug.security as sec + prev_value = sec._builtin_safe_str_cmp sec._builtin_safe_str_cmp = None - assert safe_str_cmp('a', 'ab') is False + assert safe_str_cmp("a", "ab") is False - assert safe_str_cmp('str', 'str') is True - assert safe_str_cmp('str1', 'str2') is False + assert safe_str_cmp("str", "str") is True + assert safe_str_cmp("str1", "str2") is False sec._builtin_safe_str_cmp = prev_value def test_password_hashing(): - hash0 = generate_password_hash('default') - assert check_password_hash(hash0, 'default') - assert hash0.startswith('pbkdf2:sha256:150000$') + hash0 = generate_password_hash("default") + assert check_password_hash(hash0, "default") + assert hash0.startswith("pbkdf2:sha256:150000$") - hash1 = generate_password_hash('default', 'sha1') - hash2 = generate_password_hash(u'default', method='sha1') + hash1 = generate_password_hash("default", "sha1") + hash2 = generate_password_hash(u"default", method="sha1") assert hash1 != hash2 - assert check_password_hash(hash1, 'default') - assert check_password_hash(hash2, 'default') - assert hash1.startswith('sha1$') - assert hash2.startswith('sha1$') + assert check_password_hash(hash1, "default") + assert check_password_hash(hash2, "default") + assert hash1.startswith("sha1$") + assert hash2.startswith("sha1$") with pytest.raises(ValueError): - check_password_hash('$made$up$', 'default') + check_password_hash("$made$up$", "default") with pytest.raises(ValueError): - generate_password_hash('default', 'sha1', salt_length=0) + generate_password_hash("default", "sha1", salt_length=0) - fakehash = generate_password_hash('default', method='plain') - assert fakehash == 'plain$$default' - assert check_password_hash(fakehash, 'default') + fakehash = generate_password_hash("default", method="plain") + assert fakehash == "plain$$default" + assert check_password_hash(fakehash, "default") - mhash = generate_password_hash(u'default', method='md5') - assert mhash.startswith('md5$') - assert check_password_hash(mhash, 'default') + mhash = generate_password_hash(u"default", method="md5") + assert mhash.startswith("md5$") + assert check_password_hash(mhash, "default") - legacy = 'md5$$c21f969b5f03d33d43e04f8f136e7682' - assert check_password_hash(legacy, 'default') + legacy = "md5$$c21f969b5f03d33d43e04f8f136e7682" + assert check_password_hash(legacy, "default") - legacy = u'md5$$c21f969b5f03d33d43e04f8f136e7682' - assert check_password_hash(legacy, 'default') + legacy = u"md5$$c21f969b5f03d33d43e04f8f136e7682" + assert check_password_hash(legacy, "default") def test_safe_join(): - assert safe_join('foo', 'bar/baz') == posixpath.join('foo', 'bar/baz') - assert safe_join('foo', '../bar/baz') is None - if os.name == 'nt': - assert safe_join('foo', 'foo\\bar') is None + assert safe_join("foo", "bar/baz") == posixpath.join("foo", "bar/baz") + assert safe_join("foo", "../bar/baz") is None + if os.name == "nt": + assert safe_join("foo", "foo\\bar") is None def test_safe_join_os_sep(): import werkzeug.security as sec + prev_value = sec._os_alt_seps - sec._os_alt_seps = '*' - assert safe_join('foo', 'bar/baz*') is None + sec._os_alt_seps = "*" + assert safe_join("foo", "bar/baz*") is None sec._os_alt_steps = prev_value @@ -96,41 +102,107 @@ def test_pbkdf2(): # Assumes default keylen is 20 # check('password', 'salt', 1, None, # '0c60c80f961f0e71f3a9b524af6012062fe037a6') - check('password', 'salt', 1, 20, 'sha1', - '0c60c80f961f0e71f3a9b524af6012062fe037a6') - check('password', 'salt', 2, 20, 'sha1', - 'ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957') - check('password', 'salt', 4096, 20, 'sha1', - '4b007901b765489abead49d926f721d065a429c1') - check('passwordPASSWORDpassword', 'saltSALTsaltSALTsaltSALTsaltSALTsalt', - 4096, 25, 'sha1', '3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038') - check('pass\x00word', 'sa\x00lt', 4096, 16, 'sha1', - '56fa6aa75548099dcc37d7f03425e0c3') + check("password", "salt", 1, 20, "sha1", "0c60c80f961f0e71f3a9b524af6012062fe037a6") + check("password", "salt", 2, 20, "sha1", "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957") + check( + "password", "salt", 4096, 20, "sha1", "4b007901b765489abead49d926f721d065a429c1" + ) + check( + "passwordPASSWORDpassword", + "saltSALTsaltSALTsaltSALTsaltSALTsalt", + 4096, + 25, + "sha1", + "3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038", + ) + check( + "pass\x00word", "sa\x00lt", 4096, 16, "sha1", "56fa6aa75548099dcc37d7f03425e0c3" + ) # PBKDF2-HMAC-SHA256 test vectors - check('password', 'salt', 1, 32, 'sha256', - '120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b') - check('password', 'salt', 2, 32, 'sha256', - 'ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43') - check('password', 'salt', 4096, 20, 'sha256', - 'c5e478d59288c841aa530db6845c4c8d962893a0') + check( + "password", + "salt", + 1, + 32, + "sha256", + "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b", + ) + check( + "password", + "salt", + 2, + 32, + "sha256", + "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43", + ) + check( + "password", + "salt", + 4096, + 20, + "sha256", + "c5e478d59288c841aa530db6845c4c8d962893a0", + ) # This one is from the RFC but it just takes for ages # check('password', 'salt', 16777216, 20, # 'eefe3d61cd4da4e4e9945b3d6ba2158c2634e984') # From Crypt-PBKDF2 - check('password', 'ATHENA.MIT.EDUraeburn', 1, 16, 'sha1', - 'cdedb5281bb2f801565a1122b2563515') - check('password', 'ATHENA.MIT.EDUraeburn', 1, 32, 'sha1', - 'cdedb5281bb2f801565a1122b25635150ad1f7a04bb9f3a333ecc0e2e1f70837') - check('password', 'ATHENA.MIT.EDUraeburn', 2, 16, 'sha1', - '01dbee7f4a9e243e988b62c73cda935d') - check('password', 'ATHENA.MIT.EDUraeburn', 2, 32, 'sha1', - '01dbee7f4a9e243e988b62c73cda935da05378b93244ec8f48a99e61ad799d86') - check('password', 'ATHENA.MIT.EDUraeburn', 1200, 32, 'sha1', - '5c08eb61fdf71e4e4ec3cf6ba1f5512ba7e52ddbc5e5142f708a31e2e62b1e13') - check('X' * 64, 'pass phrase equals block size', 1200, 32, 'sha1', - '139c30c0966bc32ba55fdbf212530ac9c5ec59f1a452f5cc9ad940fea0598ed1') - check('X' * 65, 'pass phrase exceeds block size', 1200, 32, 'sha1', - '9ccad6d468770cd51b10e6a68721be611a8b4d282601db3b36be9246915ec82a') + check( + "password", + "ATHENA.MIT.EDUraeburn", + 1, + 16, + "sha1", + "cdedb5281bb2f801565a1122b2563515", + ) + check( + "password", + "ATHENA.MIT.EDUraeburn", + 1, + 32, + "sha1", + "cdedb5281bb2f801565a1122b25635150ad1f7a04bb9f3a333ecc0e2e1f70837", + ) + check( + "password", + "ATHENA.MIT.EDUraeburn", + 2, + 16, + "sha1", + "01dbee7f4a9e243e988b62c73cda935d", + ) + check( + "password", + "ATHENA.MIT.EDUraeburn", + 2, + 32, + "sha1", + "01dbee7f4a9e243e988b62c73cda935da05378b93244ec8f48a99e61ad799d86", + ) + check( + "password", + "ATHENA.MIT.EDUraeburn", + 1200, + 32, + "sha1", + "5c08eb61fdf71e4e4ec3cf6ba1f5512ba7e52ddbc5e5142f708a31e2e62b1e13", + ) + check( + "X" * 64, + "pass phrase equals block size", + 1200, + 32, + "sha1", + "139c30c0966bc32ba55fdbf212530ac9c5ec59f1a452f5cc9ad940fea0598ed1", + ) + check( + "X" * 65, + "pass phrase exceeds block size", + 1200, + 32, + "sha1", + "9ccad6d468770cd51b10e6a68721be611a8b4d282601db3b36be9246915ec82a", + ) diff --git a/tests/test_serving.py b/tests/test_serving.py index 186003f3..94a46b7f 100644 --- a/tests/test_serving.py +++ b/tests/test_serving.py @@ -16,6 +16,13 @@ import sys import textwrap import time +import pytest +import requests.exceptions + +from werkzeug import __version__ as version +from werkzeug import _reloader +from werkzeug import serving + try: import OpenSSL except ImportError: @@ -27,112 +34,119 @@ except ImportError: watchdog = None try: - import httplib -except ImportError: from http import client as httplib - -import requests -import requests.exceptions -import pytest - -from werkzeug import __version__ as version, serving, _reloader +except ImportError: + import httplib def test_serving(dev_server): - server = dev_server('from werkzeug.testapp import test_app as app') - rv = requests.get('http://%s/?foo=bar&baz=blah' % server.addr).content - assert b'WSGI Information' in rv - assert b'foo=bar&baz=blah' in rv - assert b'Werkzeug/' + version.encode('ascii') in rv + server = dev_server("from werkzeug.testapp import test_app as app") + rv = requests.get("http://%s/?foo=bar&baz=blah" % server.addr).content + assert b"WSGI Information" in rv + assert b"foo=bar&baz=blah" in rv + assert b"Werkzeug/" + version.encode("ascii") in rv def test_absolute_requests(dev_server): - server = dev_server(''' - def app(environ, start_response): - assert environ['HTTP_HOST'] == 'surelynotexisting.example.com:1337' - assert environ['PATH_INFO'] == '/index.htm' - addr = environ['HTTP_X_WERKZEUG_ADDR'] - assert environ['SERVER_PORT'] == addr.split(':')[1] - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'YES'] - ''') + server = dev_server( + """ + def app(environ, start_response): + assert environ['HTTP_HOST'] == 'surelynotexisting.example.com:1337' + assert environ['PATH_INFO'] == '/index.htm' + addr = environ['HTTP_X_WERKZEUG_ADDR'] + assert environ['SERVER_PORT'] == addr.split(':')[1] + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'YES'] + """ + ) conn = httplib.HTTPConnection(server.addr) - conn.request('GET', 'http://surelynotexisting.example.com:1337/index.htm#ignorethis', - headers={'X-Werkzeug-Addr': server.addr}) + conn.request( + "GET", + "http://surelynotexisting.example.com:1337/index.htm#ignorethis", + headers={"X-Werkzeug-Addr": server.addr}, + ) res = conn.getresponse() - assert res.read() == b'YES' + assert res.read() == b"YES" def test_double_slash_path(dev_server): - server = dev_server(''' - def app(environ, start_response): - assert 'fail' not in environ['HTTP_HOST'] - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [b'YES'] - ''') + server = dev_server( + """ + def app(environ, start_response): + assert 'fail' not in environ['HTTP_HOST'] + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'YES'] + """ + ) - r = requests.get(server.url + '//fail') - assert r.content == b'YES' + r = requests.get(server.url + "//fail") + assert r.content == b"YES" def test_broken_app(dev_server): - server = dev_server(''' - def app(environ, start_response): - 1 // 0 - ''') - - r = requests.get(server.url + '/?foo=bar&baz=blah') + server = dev_server( + """ + def app(environ, start_response): + 1 // 0 + """ + ) + + r = requests.get(server.url + "/?foo=bar&baz=blah") assert r.status_code == 500 - assert 'Internal Server Error' in r.text + assert "Internal Server Error" in r.text -@pytest.mark.skipif(not hasattr(ssl, 'SSLContext'), - reason='Missing PEP 466 (Python 2.7.9+) or Python 3.') -@pytest.mark.skipif(OpenSSL is None, - reason='OpenSSL is required for cert generation.') +@pytest.mark.skipif( + not hasattr(ssl, "SSLContext"), + reason="Missing PEP 466 (Python 2.7.9+) or Python 3.", +) +@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is required for cert generation.") def test_stdlib_ssl_contexts(dev_server, tmpdir): - certificate, private_key = \ - serving.make_ssl_devcert(str(tmpdir.mkdir('certs'))) - - server = dev_server(''' - def app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.load_cert_chain(r"%s", r"%s") - kwargs['ssl_context'] = ctx - ''' % (certificate, private_key)) + certificate, private_key = serving.make_ssl_devcert(str(tmpdir.mkdir("certs"))) + + server = dev_server( + """ + def app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.load_cert_chain(r"%s", r"%s") + kwargs['ssl_context'] = ctx + """ + % (certificate, private_key) + ) assert server.addr is not None r = requests.get(server.url, verify=False) - assert r.content == b'hello' + assert r.content == b"hello" -@pytest.mark.skipif(OpenSSL is None, reason='OpenSSL is not installed.') +@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is not installed.") def test_ssl_context_adhoc(dev_server): - server = dev_server(''' - def app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - - kwargs['ssl_context'] = 'adhoc' - ''') + server = dev_server( + """ + def app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + + kwargs['ssl_context'] = 'adhoc' + """ + ) r = requests.get(server.url, verify=False) - assert r.content == b'hello' + assert r.content == b"hello" -@pytest.mark.skipif(OpenSSL is None, reason='OpenSSL is not installed.') +@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is not installed.") def test_make_ssl_devcert(tmpdir): - certificate, private_key = \ - serving.make_ssl_devcert(str(tmpdir)) + certificate, private_key = serving.make_ssl_devcert(str(tmpdir)) assert os.path.isfile(certificate) assert os.path.isfile(private_key) -@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.') +@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.") def test_reloader_broken_imports(tmpdir, dev_server): # We explicitly assert that the server reloads on change, even though in # this case the import could've just been retried. This is to assert @@ -142,108 +156,126 @@ def test_reloader_broken_imports(tmpdir, dev_server): # of directories, this only works for the watchdog reloader. The stat # reloader is too inefficient to watch such a large amount of files. - real_app = tmpdir.join('real_app.py') + real_app = tmpdir.join("real_app.py") real_app.write("lol syntax error") - server = dev_server(''' - trials = [] - def app(environ, start_response): - assert not trials, 'should have reloaded' - trials.append(1) - import real_app - return real_app.real_app(environ, start_response) - - kwargs['use_reloader'] = True - kwargs['reloader_interval'] = 0.1 - kwargs['reloader_type'] = 'watchdog' - ''') + server = dev_server( + """ + trials = [] + def app(environ, start_response): + assert not trials, 'should have reloaded' + trials.append(1) + import real_app + return real_app.real_app(environ, start_response) + + kwargs['use_reloader'] = True + kwargs['reloader_interval'] = 0.1 + kwargs['reloader_type'] = 'watchdog' + """ + ) server.wait_for_reloader_loop() r = requests.get(server.url) assert r.status_code == 500 - real_app.write(textwrap.dedent(''' - def real_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - ''')) + real_app.write( + textwrap.dedent( + """ + def real_app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + """ + ) + ) server.wait_for_reloader() r = requests.get(server.url) assert r.status_code == 200 - assert r.content == b'hello' + assert r.content == b"hello" -@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.') +@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.") def test_reloader_nested_broken_imports(tmpdir, dev_server): - real_app = tmpdir.mkdir('real_app') - real_app.join('__init__.py').write('from real_app.sub import real_app') - sub = real_app.mkdir('sub').join('__init__.py') + real_app = tmpdir.mkdir("real_app") + real_app.join("__init__.py").write("from real_app.sub import real_app") + sub = real_app.mkdir("sub").join("__init__.py") sub.write("lol syntax error") - server = dev_server(''' - trials = [] - def app(environ, start_response): - assert not trials, 'should have reloaded' - trials.append(1) - import real_app - return real_app.real_app(environ, start_response) - - kwargs['use_reloader'] = True - kwargs['reloader_interval'] = 0.1 - kwargs['reloader_type'] = 'watchdog' - ''') + server = dev_server( + """ + trials = [] + def app(environ, start_response): + assert not trials, 'should have reloaded' + trials.append(1) + import real_app + return real_app.real_app(environ, start_response) + + kwargs['use_reloader'] = True + kwargs['reloader_interval'] = 0.1 + kwargs['reloader_type'] = 'watchdog' + """ + ) server.wait_for_reloader_loop() r = requests.get(server.url) assert r.status_code == 500 - sub.write(textwrap.dedent(''' - def real_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - ''')) + sub.write( + textwrap.dedent( + """ + def real_app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + """ + ) + ) server.wait_for_reloader() r = requests.get(server.url) assert r.status_code == 200 - assert r.content == b'hello' + assert r.content == b"hello" -@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.') +@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.") def test_reloader_reports_correct_file(tmpdir, dev_server): - real_app = tmpdir.join('real_app.py') - real_app.write(textwrap.dedent(''' - def real_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - ''')) - - server = dev_server(''' - trials = [] - def app(environ, start_response): - assert not trials, 'should have reloaded' - trials.append(1) - import real_app - return real_app.real_app(environ, start_response) - - kwargs['use_reloader'] = True - kwargs['reloader_interval'] = 0.1 - kwargs['reloader_type'] = 'watchdog' - ''') + real_app = tmpdir.join("real_app.py") + real_app.write( + textwrap.dedent( + """ + def real_app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + """ + ) + ) + + server = dev_server( + """ + trials = [] + def app(environ, start_response): + assert not trials, 'should have reloaded' + trials.append(1) + import real_app + return real_app.real_app(environ, start_response) + + kwargs['use_reloader'] = True + kwargs['reloader_interval'] = 0.1 + kwargs['reloader_type'] = 'watchdog' + """ + ) server.wait_for_reloader_loop() r = requests.get(server.url) assert r.status_code == 200 - assert r.content == b'hello' + assert r.content == b"hello" - real_app_binary = tmpdir.join('real_app.pyc') - real_app_binary.write('anything is fine here') + real_app_binary = tmpdir.join("real_app.pyc") + real_app_binary.write("anything is fine here") server.wait_for_reloader() change_event = " * Detected change in '%(path)s', reloading" % { # need to double escape Windows paths - 'path': str(real_app_binary).replace("\\", "\\\\") + "path": str(real_app_binary).replace("\\", "\\\\") } server.logfile.seek(0) for i in range(20): @@ -252,17 +284,17 @@ def test_reloader_reports_correct_file(tmpdir, dev_server): if change_event in log: break else: - raise RuntimeError('Change event not detected.') + raise RuntimeError("Change event not detected.") def test_windows_get_args_for_reloading(monkeypatch, tmpdir): - test_py_exe = r'C:\Users\test\AppData\Local\Programs\Python\Python36\python.exe' - monkeypatch.setattr(os, 'name', 'nt') - monkeypatch.setattr(sys, 'executable', test_py_exe) - test_exe = tmpdir.mkdir('test').join('test.exe') - monkeypatch.setattr(sys, 'argv', [test_exe.strpath, 'run']) + test_py_exe = r"C:\Users\test\AppData\Local\Programs\Python\Python36\python.exe" + monkeypatch.setattr(os, "name", "nt") + monkeypatch.setattr(sys, "executable", test_py_exe) + test_exe = tmpdir.mkdir("test").join("test.exe") + monkeypatch.setattr(sys, "argv", [test_exe.strpath, "run"]) rv = _reloader._get_args_for_reloading() - assert rv == [test_exe.strpath, 'run'] + assert rv == [test_exe.strpath, "run"] def test_monkeypatched_sleep(tmpdir): @@ -273,20 +305,24 @@ def test_monkeypatched_sleep(tmpdir): # `eventlet.monkey_patch` before importing `_reloader`, `time.sleep` is a # python function, and subsequently calling `ReloaderLoop._sleep` fails # with a TypeError. This test checks that _sleep is attached correctly. - script = tmpdir.mkdir('app').join('test.py') - script.write(textwrap.dedent(''' - import time - - def sleep(secs): - pass - - # simulate eventlet.monkey_patch by replacing the builtin sleep - # with a regular function before _reloader is imported - time.sleep = sleep - - from werkzeug._reloader import ReloaderLoop - ReloaderLoop()._sleep(0) - ''')) + script = tmpdir.mkdir("app").join("test.py") + script.write( + textwrap.dedent( + """ + import time + + def sleep(secs): + pass + + # simulate eventlet.monkey_patch by replacing the builtin sleep + # with a regular function before _reloader is imported + time.sleep = sleep + + from werkzeug._reloader import ReloaderLoop + ReloaderLoop()._sleep(0) + """ + ) + ) subprocess.check_call([sys.executable, str(script)]) @@ -295,185 +331,187 @@ def test_wrong_protocol(dev_server): # traceback # See https://github.com/pallets/werkzeug/pull/838 - server = dev_server(''' - def app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] - ''') + server = dev_server( + """ + def app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'hello'] + """ + ) with pytest.raises(requests.exceptions.ConnectionError): - requests.get('https://%s/' % server.addr) + requests.get("https://%s/" % server.addr) log = server.logfile.read() - assert 'Traceback' not in log - assert '\n127.0.0.1' in log + assert "Traceback" not in log + assert "\n127.0.0.1" in log def test_absent_content_length_and_content_type(dev_server): - server = dev_server(''' - def app(environ, start_response): - assert 'CONTENT_LENGTH' not in environ - assert 'CONTENT_TYPE' not in environ - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'YES'] - ''') + server = dev_server( + """ + def app(environ, start_response): + assert 'CONTENT_LENGTH' not in environ + assert 'CONTENT_TYPE' not in environ + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'YES'] + """ + ) r = requests.get(server.url) - assert r.content == b'YES' + assert r.content == b"YES" def test_set_content_length_and_content_type_if_provided_by_client(dev_server): - server = dev_server(''' - def app(environ, start_response): - assert environ['CONTENT_LENGTH'] == '233' - assert environ['CONTENT_TYPE'] == 'application/json' - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'YES'] - ''') - - r = requests.get(server.url, headers={ - 'content_length': '233', - 'content_type': 'application/json' - }) - assert r.content == b'YES' + server = dev_server( + """ + def app(environ, start_response): + assert environ['CONTENT_LENGTH'] == '233' + assert environ['CONTENT_TYPE'] == 'application/json' + start_response('200 OK', [('Content-Type', 'text/html')]) + return [b'YES'] + """ + ) + + r = requests.get( + server.url, + headers={"content_length": "233", "content_type": "application/json"}, + ) + assert r.content == b"YES" def test_port_must_be_integer(dev_server): def app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return [b'hello'] + start_response("200 OK", [("Content-Type", "text/html")]) + return [b"hello"] with pytest.raises(TypeError) as excinfo: - serving.run_simple(hostname='localhost', port='5001', - application=app, use_reloader=True) - assert 'port must be an integer' in str(excinfo.value) + serving.run_simple( + hostname="localhost", port="5001", application=app, use_reloader=True + ) + assert "port must be an integer" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - serving.run_simple(hostname='localhost', port='5001', - application=app, use_reloader=False) - assert 'port must be an integer' in str(excinfo.value) + serving.run_simple( + hostname="localhost", port="5001", application=app, use_reloader=False + ) + assert "port must be an integer" in str(excinfo.value) def test_chunked_encoding(dev_server): - server = dev_server(r''' - from werkzeug.wrappers import Request - def app(environ, start_response): - assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked' - assert environ.get('wsgi.input_terminated', False) - request = Request(environ) - assert request.mimetype == 'multipart/form-data' - assert request.files['file'].read() == b'This is a test\n' - assert request.form['type'] == 'text/plain' - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [b'YES'] - ''') - - testfile = os.path.join(os.path.dirname(__file__), 'res', 'chunked.txt') - - if sys.version_info[0] == 2: - from httplib import HTTPConnection - else: - from http.client import HTTPConnection - - conn = HTTPConnection('127.0.0.1', server.port) + server = dev_server( + r""" + from werkzeug.wrappers import Request + def app(environ, start_response): + assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked' + assert environ.get('wsgi.input_terminated', False) + request = Request(environ) + assert request.mimetype == 'multipart/form-data' + assert request.files['file'].read() == b'This is a test\n' + assert request.form['type'] == 'text/plain' + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'YES'] + """ + ) + + testfile = os.path.join(os.path.dirname(__file__), "res", "chunked.http") + + conn = httplib.HTTPConnection("127.0.0.1", server.port) conn.connect() - conn.putrequest('POST', '/', skip_host=1, skip_accept_encoding=1) - conn.putheader('Accept', 'text/plain') - conn.putheader('Transfer-Encoding', 'chunked') + conn.putrequest("POST", "/", skip_host=1, skip_accept_encoding=1) + conn.putheader("Accept", "text/plain") + conn.putheader("Transfer-Encoding", "chunked") conn.putheader( - 'Content-Type', - 'multipart/form-data; boundary=' - '--------------------------898239224156930639461866') + "Content-Type", + "multipart/form-data; boundary=" + "--------------------------898239224156930639461866", + ) conn.endheaders() - with open(testfile, 'rb') as f: + with open(testfile, "rb") as f: conn.send(f.read()) res = conn.getresponse() assert res.status == 200 - assert res.read() == b'YES' + assert res.read() == b"YES" conn.close() def test_chunked_encoding_with_content_length(dev_server): - server = dev_server(r''' - from werkzeug.wrappers import Request - def app(environ, start_response): - assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked' - assert environ.get('wsgi.input_terminated', False) - request = Request(environ) - assert request.mimetype == 'multipart/form-data' - assert request.files['file'].read() == b'This is a test\n' - assert request.form['type'] == 'text/plain' - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [b'YES'] - ''') - - testfile = os.path.join(os.path.dirname(__file__), 'res', 'chunked.txt') - - if sys.version_info[0] == 2: - from httplib import HTTPConnection - else: - from http.client import HTTPConnection - - conn = HTTPConnection('127.0.0.1', server.port) + server = dev_server( + r""" + from werkzeug.wrappers import Request + def app(environ, start_response): + assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked' + assert environ.get('wsgi.input_terminated', False) + request = Request(environ) + assert request.mimetype == 'multipart/form-data' + assert request.files['file'].read() == b'This is a test\n' + assert request.form['type'] == 'text/plain' + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'YES'] + """ + ) + + testfile = os.path.join(os.path.dirname(__file__), "res", "chunked.http") + + conn = httplib.HTTPConnection("127.0.0.1", server.port) conn.connect() - conn.putrequest('POST', '/', skip_host=1, skip_accept_encoding=1) - conn.putheader('Accept', 'text/plain') - conn.putheader('Transfer-Encoding', 'chunked') + conn.putrequest("POST", "/", skip_host=1, skip_accept_encoding=1) + conn.putheader("Accept", "text/plain") + conn.putheader("Transfer-Encoding", "chunked") # Content-Length is invalid for chunked, but some libraries might send it - conn.putheader('Content-Length', '372') + conn.putheader("Content-Length", "372") conn.putheader( - 'Content-Type', - 'multipart/form-data; boundary=' - '--------------------------898239224156930639461866') + "Content-Type", + "multipart/form-data; boundary=" + "--------------------------898239224156930639461866", + ) conn.endheaders() - with open(testfile, 'rb') as f: + with open(testfile, "rb") as f: conn.send(f.read()) res = conn.getresponse() assert res.status == 200 - assert res.read() == b'YES' + assert res.read() == b"YES" conn.close() def test_multiple_headers_concatenated_per_rfc_3875_section_4_1_18(dev_server): - server = dev_server(r''' - from werkzeug.wrappers import Response - def app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [environ['HTTP_XYZ'].encode()] - ''') - - if sys.version_info[0] == 2: - from httplib import HTTPConnection - else: - from http.client import HTTPConnection - conn = HTTPConnection('127.0.0.1', server.port) + server = dev_server( + r""" + from werkzeug.wrappers import Response + def app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [environ['HTTP_XYZ'].encode()] + """ + ) + + conn = httplib.HTTPConnection("127.0.0.1", server.port) conn.connect() - conn.putrequest('GET', '/') - conn.putheader('Accept', 'text/plain') - conn.putheader('XYZ', ' a ') - conn.putheader('X-INGNORE-1', 'Some nonsense') - conn.putheader('XYZ', ' b') - conn.putheader('X-INGNORE-2', 'Some nonsense') - conn.putheader('XYZ', 'c ') - conn.putheader('X-INGNORE-3', 'Some nonsense') - conn.putheader('XYZ', 'd') + conn.putrequest("GET", "/") + conn.putheader("Accept", "text/plain") + conn.putheader("XYZ", " a ") + conn.putheader("X-INGNORE-1", "Some nonsense") + conn.putheader("XYZ", " b") + conn.putheader("X-INGNORE-2", "Some nonsense") + conn.putheader("XYZ", "c ") + conn.putheader("X-INGNORE-3", "Some nonsense") + conn.putheader("XYZ", "d") conn.endheaders() - conn.send(b'') + conn.send(b"") res = conn.getresponse() assert res.status == 200 - assert res.read() == b'a ,b,c ,d' + assert res.read() == b"a ,b,c ,d" conn.close() def can_test_unix_socket(): - if not hasattr(socket, 'AF_UNIX'): + if not hasattr(socket, "AF_UNIX"): return False try: import requests_unixsocket # noqa: F401 @@ -482,11 +520,15 @@ def can_test_unix_socket(): return True -@pytest.mark.skipif(not can_test_unix_socket(), reason='Only works on UNIX') +@pytest.mark.skipif(not can_test_unix_socket(), reason="Only works on UNIX") def test_unix_socket(tmpdir, dev_server): - socket_f = str(tmpdir.join('socket')) - dev_server(''' - app = None - kwargs['hostname'] = {socket!r} - '''.format(socket='unix://' + socket_f)) + socket_f = str(tmpdir.join("socket")) + dev_server( + """ + app = None + kwargs['hostname'] = {socket!r} + """.format( + socket="unix://" + socket_f + ) + ) assert os.path.exists(socket_f) diff --git a/tests/test_test.py b/tests/test_test.py index 82c4f3f9..e7f6306b 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -8,168 +8,176 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ - import json -import pytest - import sys -from io import BytesIO -from werkzeug._compat import iteritems, to_bytes, implements_iterator from functools import partial +from io import BytesIO -from tests import strict_eq +import pytest -from werkzeug.wrappers import Request, Response, BaseResponse -from werkzeug.test import Client, EnvironBuilder, create_environ, \ - ClientRedirectError, stream_encode_multipart, run_wsgi_app -from werkzeug.utils import redirect +from . import strict_eq +from werkzeug._compat import implements_iterator +from werkzeug._compat import iteritems +from werkzeug._compat import to_bytes +from werkzeug.datastructures import FileStorage +from werkzeug.datastructures import MultiDict from werkzeug.formparser import parse_form_data -from werkzeug.datastructures import MultiDict, FileStorage +from werkzeug.test import Client +from werkzeug.test import ClientRedirectError +from werkzeug.test import create_environ +from werkzeug.test import EnvironBuilder +from werkzeug.test import run_wsgi_app +from werkzeug.test import stream_encode_multipart +from werkzeug.utils import redirect +from werkzeug.wrappers import BaseResponse +from werkzeug.wrappers import Request +from werkzeug.wrappers import Response def cookie_app(environ, start_response): """A WSGI application which sets a cookie, and returns as a response any cookie which exists. """ - response = Response(environ.get('HTTP_COOKIE', 'No Cookie'), - mimetype='text/plain') - response.set_cookie('test', 'test') + response = Response(environ.get("HTTP_COOKIE", "No Cookie"), mimetype="text/plain") + response.set_cookie("test", "test") return response(environ, start_response) def redirect_loop_app(environ, start_response): - response = redirect('http://localhost/some/redirect/') + response = redirect("http://localhost/some/redirect/") return response(environ, start_response) def redirect_with_get_app(environ, start_response): req = Request(environ) - if req.url not in ('http://localhost/', - 'http://localhost/first/request', - 'http://localhost/some/redirect/'): + if req.url not in ( + "http://localhost/", + "http://localhost/first/request", + "http://localhost/some/redirect/", + ): assert False, 'redirect_demo_app() did not expect URL "%s"' % req.url - if '/some/redirect' not in req.url: - response = redirect('http://localhost/some/redirect/') + if "/some/redirect" not in req.url: + response = redirect("http://localhost/some/redirect/") else: - response = Response('current url: %s' % req.url) + response = Response("current url: %s" % req.url) return response(environ, start_response) def external_redirect_demo_app(environ, start_response): - response = redirect('http://example.com/') + response = redirect("http://example.com/") return response(environ, start_response) def external_subdomain_redirect_demo_app(environ, start_response): - if 'test.example.com' in environ['HTTP_HOST']: - response = Response('redirected successfully to subdomain') + if "test.example.com" in environ["HTTP_HOST"]: + response = Response("redirected successfully to subdomain") else: - response = redirect('http://test.example.com/login') + response = redirect("http://test.example.com/login") return response(environ, start_response) def multi_value_post_app(environ, start_response): req = Request(environ) - assert req.form['field'] == 'val1', req.form['field'] - assert req.form.getlist('field') == ['val1', 'val2'], req.form.getlist('field') - response = Response('ok') + assert req.form["field"] == "val1", req.form["field"] + assert req.form.getlist("field") == ["val1", "val2"], req.form.getlist("field") + response = Response("ok") return response(environ, start_response) def test_cookie_forging(): c = Client(cookie_app) - c.set_cookie('localhost', 'foo', 'bar') + c.set_cookie("localhost", "foo", "bar") appiter, code, headers = c.open() - strict_eq(list(appiter), [b'foo=bar']) + strict_eq(list(appiter), [b"foo=bar"]) def test_set_cookie_app(): c = Client(cookie_app) appiter, code, headers = c.open() - assert 'Set-Cookie' in dict(headers) + assert "Set-Cookie" in dict(headers) def test_cookiejar_stores_cookie(): c = Client(cookie_app) appiter, code, headers = c.open() - assert 'test' in c.cookie_jar._cookies['localhost.local']['/'] + assert "test" in c.cookie_jar._cookies["localhost.local"]["/"] def test_no_initial_cookie(): c = Client(cookie_app) appiter, code, headers = c.open() - strict_eq(b''.join(appiter), b'No Cookie') + strict_eq(b"".join(appiter), b"No Cookie") def test_resent_cookie(): c = Client(cookie_app) c.open() appiter, code, headers = c.open() - strict_eq(b''.join(appiter), b'test=test') + strict_eq(b"".join(appiter), b"test=test") def test_disable_cookies(): c = Client(cookie_app, use_cookies=False) c.open() appiter, code, headers = c.open() - strict_eq(b''.join(appiter), b'No Cookie') + strict_eq(b"".join(appiter), b"No Cookie") def test_cookie_for_different_path(): c = Client(cookie_app) - c.open('/path1') - appiter, code, headers = c.open('/path2') - strict_eq(b''.join(appiter), b'test=test') + c.open("/path1") + appiter, code, headers = c.open("/path2") + strict_eq(b"".join(appiter), b"test=test") def test_environ_builder_basics(): b = EnvironBuilder() assert b.content_type is None - b.method = 'POST' + b.method = "POST" assert b.content_type is None - b.form['test'] = 'normal value' - assert b.content_type == 'application/x-www-form-urlencoded' - b.files.add_file('test', BytesIO(b'test contents'), 'test.txt') - assert b.files['test'].content_type == 'text/plain' - b.form['test_int'] = 1 - assert b.content_type == 'multipart/form-data' + b.form["test"] = "normal value" + assert b.content_type == "application/x-www-form-urlencoded" + b.files.add_file("test", BytesIO(b"test contents"), "test.txt") + assert b.files["test"].content_type == "text/plain" + b.form["test_int"] = 1 + assert b.content_type == "multipart/form-data" req = b.get_request() b.close() - strict_eq(req.url, u'http://localhost/') - strict_eq(req.method, 'POST') - strict_eq(req.form['test'], u'normal value') - assert req.files['test'].content_type == 'text/plain' - strict_eq(req.files['test'].filename, u'test.txt') - strict_eq(req.files['test'].read(), b'test contents') + strict_eq(req.url, u"http://localhost/") + strict_eq(req.method, "POST") + strict_eq(req.form["test"], u"normal value") + assert req.files["test"].content_type == "text/plain" + strict_eq(req.files["test"].filename, u"test.txt") + strict_eq(req.files["test"].read(), b"test contents") def test_environ_builder_data(): - b = EnvironBuilder(data='foo') - assert b.input_stream.getvalue() == b'foo' - b = EnvironBuilder(data=b'foo') - assert b.input_stream.getvalue() == b'foo' + b = EnvironBuilder(data="foo") + assert b.input_stream.getvalue() == b"foo" + b = EnvironBuilder(data=b"foo") + assert b.input_stream.getvalue() == b"foo" - b = EnvironBuilder(data={'foo': 'bar'}) - assert b.form['foo'] == 'bar' - b = EnvironBuilder(data={'foo': ['bar1', 'bar2']}) - assert b.form.getlist('foo') == ['bar1', 'bar2'] + b = EnvironBuilder(data={"foo": "bar"}) + assert b.form["foo"] == "bar" + b = EnvironBuilder(data={"foo": ["bar1", "bar2"]}) + assert b.form.getlist("foo") == ["bar1", "bar2"] def check_list_content(b, length): - foo = b.files.getlist('foo') + foo = b.files.getlist("foo") assert len(foo) == length for obj in foo: assert isinstance(obj, FileStorage) - b = EnvironBuilder(data={'foo': BytesIO()}) + b = EnvironBuilder(data={"foo": BytesIO()}) check_list_content(b, 1) - b = EnvironBuilder(data={'foo': [BytesIO(), BytesIO()]}) + b = EnvironBuilder(data={"foo": [BytesIO(), BytesIO()]}) check_list_content(b, 2) - b = EnvironBuilder(data={'foo': (BytesIO(),)}) + b = EnvironBuilder(data={"foo": (BytesIO(),)}) check_list_content(b, 1) - b = EnvironBuilder(data={'foo': [(BytesIO(),), (BytesIO(),)]}) + b = EnvironBuilder(data={"foo": [(BytesIO(),), (BytesIO(),)]}) check_list_content(b, 2) @@ -188,151 +196,157 @@ def test_environ_builder_json(): def test_environ_builder_headers(): - b = EnvironBuilder(environ_base={'HTTP_USER_AGENT': 'Foo/0.1'}, - environ_overrides={'wsgi.version': (1, 1)}) - b.headers['X-Beat-My-Horse'] = 'very well sir' + b = EnvironBuilder( + environ_base={"HTTP_USER_AGENT": "Foo/0.1"}, + environ_overrides={"wsgi.version": (1, 1)}, + ) + b.headers["X-Beat-My-Horse"] = "very well sir" env = b.get_environ() - strict_eq(env['HTTP_USER_AGENT'], 'Foo/0.1') - strict_eq(env['HTTP_X_BEAT_MY_HORSE'], 'very well sir') - strict_eq(env['wsgi.version'], (1, 1)) + strict_eq(env["HTTP_USER_AGENT"], "Foo/0.1") + strict_eq(env["HTTP_X_BEAT_MY_HORSE"], "very well sir") + strict_eq(env["wsgi.version"], (1, 1)) - b.headers['User-Agent'] = 'Bar/1.0' + b.headers["User-Agent"] = "Bar/1.0" env = b.get_environ() - strict_eq(env['HTTP_USER_AGENT'], 'Bar/1.0') + strict_eq(env["HTTP_USER_AGENT"], "Bar/1.0") def test_environ_builder_headers_content_type(): - b = EnvironBuilder(headers={'Content-Type': 'text/plain'}) + b = EnvironBuilder(headers={"Content-Type": "text/plain"}) env = b.get_environ() - assert env['CONTENT_TYPE'] == 'text/plain' - b = EnvironBuilder(content_type='text/html', - headers={'Content-Type': 'text/plain'}) + assert env["CONTENT_TYPE"] == "text/plain" + b = EnvironBuilder(content_type="text/html", headers={"Content-Type": "text/plain"}) env = b.get_environ() - assert env['CONTENT_TYPE'] == 'text/html' + assert env["CONTENT_TYPE"] == "text/html" b = EnvironBuilder() env = b.get_environ() - assert 'CONTENT_TYPE' not in env + assert "CONTENT_TYPE" not in env def test_environ_builder_paths(): - b = EnvironBuilder(path='/foo', base_url='http://example.com/') - strict_eq(b.base_url, 'http://example.com/') - strict_eq(b.path, '/foo') - strict_eq(b.script_root, '') - strict_eq(b.host, 'example.com') - - b = EnvironBuilder(path='/foo', base_url='http://example.com/bar') - strict_eq(b.base_url, 'http://example.com/bar/') - strict_eq(b.path, '/foo') - strict_eq(b.script_root, '/bar') - strict_eq(b.host, 'example.com') - - b.host = 'localhost' - strict_eq(b.base_url, 'http://localhost/bar/') - b.base_url = 'http://localhost:8080/' - strict_eq(b.host, 'localhost:8080') - strict_eq(b.server_name, 'localhost') + b = EnvironBuilder(path="/foo", base_url="http://example.com/") + strict_eq(b.base_url, "http://example.com/") + strict_eq(b.path, "/foo") + strict_eq(b.script_root, "") + strict_eq(b.host, "example.com") + + b = EnvironBuilder(path="/foo", base_url="http://example.com/bar") + strict_eq(b.base_url, "http://example.com/bar/") + strict_eq(b.path, "/foo") + strict_eq(b.script_root, "/bar") + strict_eq(b.host, "example.com") + + b.host = "localhost" + strict_eq(b.base_url, "http://localhost/bar/") + b.base_url = "http://localhost:8080/" + strict_eq(b.host, "localhost:8080") + strict_eq(b.server_name, "localhost") strict_eq(b.server_port, 8080) - b.host = 'foo.invalid' - b.url_scheme = 'https' - b.script_root = '/test' + b.host = "foo.invalid" + b.url_scheme = "https" + b.script_root = "/test" env = b.get_environ() - strict_eq(env['SERVER_NAME'], 'foo.invalid') - strict_eq(env['SERVER_PORT'], '443') - strict_eq(env['SCRIPT_NAME'], '/test') - strict_eq(env['PATH_INFO'], '/foo') - strict_eq(env['HTTP_HOST'], 'foo.invalid') - strict_eq(env['wsgi.url_scheme'], 'https') - strict_eq(b.base_url, 'https://foo.invalid/test/') + strict_eq(env["SERVER_NAME"], "foo.invalid") + strict_eq(env["SERVER_PORT"], "443") + strict_eq(env["SCRIPT_NAME"], "/test") + strict_eq(env["PATH_INFO"], "/foo") + strict_eq(env["HTTP_HOST"], "foo.invalid") + strict_eq(env["wsgi.url_scheme"], "https") + strict_eq(b.base_url, "https://foo.invalid/test/") def test_environ_builder_content_type(): builder = EnvironBuilder() assert builder.content_type is None - builder.method = 'POST' + builder.method = "POST" assert builder.content_type is None - builder.method = 'PUT' + builder.method = "PUT" assert builder.content_type is None - builder.method = 'PATCH' + builder.method = "PATCH" assert builder.content_type is None - builder.method = 'DELETE' + builder.method = "DELETE" assert builder.content_type is None - builder.method = 'GET' + builder.method = "GET" assert builder.content_type is None - builder.form['foo'] = 'bar' - assert builder.content_type == 'application/x-www-form-urlencoded' - builder.files.add_file('blafasel', BytesIO(b'foo'), 'test.txt') - assert builder.content_type == 'multipart/form-data' + builder.form["foo"] = "bar" + assert builder.content_type == "application/x-www-form-urlencoded" + builder.files.add_file("blafasel", BytesIO(b"foo"), "test.txt") + assert builder.content_type == "multipart/form-data" req = builder.get_request() - strict_eq(req.form['foo'], u'bar') - strict_eq(req.files['blafasel'].read(), b'foo') + strict_eq(req.form["foo"], u"bar") + strict_eq(req.files["blafasel"].read(), b"foo") def test_environ_builder_stream_switch(): - d = MultiDict(dict(foo=u'bar', blub=u'blah', hu=u'hum')) + d = MultiDict(dict(foo=u"bar", blub=u"blah", hu=u"hum")) for use_tempfile in False, True: stream, length, boundary = stream_encode_multipart( - d, use_tempfile, threshold=150) + d, use_tempfile, threshold=150 + ) assert isinstance(stream, BytesIO) != use_tempfile - form = parse_form_data({'wsgi.input': stream, 'CONTENT_LENGTH': str(length), - 'CONTENT_TYPE': 'multipart/form-data; boundary="%s"' % - boundary})[1] + form = parse_form_data( + { + "wsgi.input": stream, + "CONTENT_LENGTH": str(length), + "CONTENT_TYPE": 'multipart/form-data; boundary="%s"' % boundary, + } + )[1] strict_eq(form, d) stream.close() def test_environ_builder_unicode_file_mix(): for use_tempfile in False, True: - f = FileStorage(BytesIO(u'\N{SNOWMAN}'.encode('utf-8')), - 'snowman.txt') - d = MultiDict(dict(f=f, s=u'\N{SNOWMAN}')) + f = FileStorage(BytesIO(u"\N{SNOWMAN}".encode("utf-8")), "snowman.txt") + d = MultiDict(dict(f=f, s=u"\N{SNOWMAN}")) stream, length, boundary = stream_encode_multipart( - d, use_tempfile, threshold=150) + d, use_tempfile, threshold=150 + ) assert isinstance(stream, BytesIO) != use_tempfile - _, form, files = parse_form_data({ - 'wsgi.input': stream, - 'CONTENT_LENGTH': str(length), - 'CONTENT_TYPE': 'multipart/form-data; boundary="%s"' % - boundary - }) - strict_eq(form['s'], u'\N{SNOWMAN}') - strict_eq(files['f'].name, 'f') - strict_eq(files['f'].filename, u'snowman.txt') - strict_eq(files['f'].read(), - u'\N{SNOWMAN}'.encode('utf-8')) + _, form, files = parse_form_data( + { + "wsgi.input": stream, + "CONTENT_LENGTH": str(length), + "CONTENT_TYPE": 'multipart/form-data; boundary="%s"' % boundary, + } + ) + strict_eq(form["s"], u"\N{SNOWMAN}") + strict_eq(files["f"].name, "f") + strict_eq(files["f"].filename, u"snowman.txt") + strict_eq(files["f"].read(), u"\N{SNOWMAN}".encode("utf-8")) stream.close() def test_create_environ(): - env = create_environ('/foo?bar=baz', 'http://example.org/') + env = create_environ("/foo?bar=baz", "http://example.org/") expected = { - 'wsgi.multiprocess': False, - 'wsgi.version': (1, 0), - 'wsgi.run_once': False, - 'wsgi.errors': sys.stderr, - 'wsgi.multithread': False, - 'wsgi.url_scheme': 'http', - 'SCRIPT_NAME': '', - 'SERVER_NAME': 'example.org', - 'REQUEST_METHOD': 'GET', - 'HTTP_HOST': 'example.org', - 'PATH_INFO': '/foo', - 'SERVER_PORT': '80', - 'SERVER_PROTOCOL': 'HTTP/1.1', - 'QUERY_STRING': 'bar=baz' + "wsgi.multiprocess": False, + "wsgi.version": (1, 0), + "wsgi.run_once": False, + "wsgi.errors": sys.stderr, + "wsgi.multithread": False, + "wsgi.url_scheme": "http", + "SCRIPT_NAME": "", + "SERVER_NAME": "example.org", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.org", + "PATH_INFO": "/foo", + "SERVER_PORT": "80", + "SERVER_PROTOCOL": "HTTP/1.1", + "QUERY_STRING": "bar=baz", } for key, value in iteritems(expected): assert env[key] == value - strict_eq(env['wsgi.input'].read(0), b'') - strict_eq(create_environ('/foo', 'http://example.com/')['SCRIPT_NAME'], '') + strict_eq(env["wsgi.input"].read(0), b"") + strict_eq(create_environ("/foo", "http://example.com/")["SCRIPT_NAME"], "") def test_create_environ_query_string_error(): with pytest.raises(ValueError): - create_environ('/foo?bar=baz', query_string={'a': 'b'}) + create_environ("/foo?bar=baz", query_string={"a": "b"}) def test_builder_from_environ(): @@ -341,7 +355,7 @@ def test_builder_from_environ(): base_url="https://example.com/base", query_string={"name": "Werkzeug"}, data={"foo": "bar"}, - headers={"X-Foo": "bar"} + headers={"X-Foo": "bar"}, ) builder = EnvironBuilder.from_environ(environ) try: @@ -355,39 +369,38 @@ def test_file_closing(): closed = [] class SpecialInput(object): - def read(self, size): - return '' + return "" def close(self): closed.append(self) - create_environ(data={'foo': SpecialInput()}) + create_environ(data={"foo": SpecialInput()}) strict_eq(len(closed), 1) builder = EnvironBuilder() - builder.files.add_file('blah', SpecialInput()) + builder.files.add_file("blah", SpecialInput()) builder.close() strict_eq(len(closed), 2) def test_follow_redirect(): - env = create_environ('/', base_url='http://localhost') + env = create_environ("/", base_url="http://localhost") c = Client(redirect_with_get_app) appiter, code, headers = c.open(environ_overrides=env, follow_redirects=True) - strict_eq(code, '200 OK') - strict_eq(b''.join(appiter), b'current url: http://localhost/some/redirect/') + strict_eq(code, "200 OK") + strict_eq(b"".join(appiter), b"current url: http://localhost/some/redirect/") # Test that the :cls:`Client` is aware of user defined response wrappers c = Client(redirect_with_get_app, response_wrapper=BaseResponse) - resp = c.get('/', follow_redirects=True) + resp = c.get("/", follow_redirects=True) strict_eq(resp.status_code, 200) - strict_eq(resp.data, b'current url: http://localhost/some/redirect/') + strict_eq(resp.data, b"current url: http://localhost/some/redirect/") # test with URL other than '/' to make sure redirected URL's are correct c = Client(redirect_with_get_app, response_wrapper=BaseResponse) - resp = c.get('/first/request', follow_redirects=True) + resp = c.get("/first/request", follow_redirects=True) strict_eq(resp.status_code, 200) - strict_eq(resp.data, b'current url: http://localhost/some/redirect/') + strict_eq(resp.data, b"current url: http://localhost/some/redirect/") def test_follow_local_redirect(): @@ -396,21 +409,20 @@ def test_follow_local_redirect(): def local_redirect_app(environ, start_response): req = Request(environ) - if '/from/location' in req.url: - response = redirect('/to/location', Response=LocalResponse) + if "/from/location" in req.url: + response = redirect("/to/location", Response=LocalResponse) else: - response = Response('current path: %s' % req.path) + response = Response("current path: %s" % req.path) return response(environ, start_response) c = Client(local_redirect_app, response_wrapper=BaseResponse) - resp = c.get('/from/location', follow_redirects=True) + resp = c.get("/from/location", follow_redirects=True) strict_eq(resp.status_code, 200) - strict_eq(resp.data, b'current path: /to/location') + strict_eq(resp.data, b"current path: /to/location") @pytest.mark.parametrize( - ("code", "keep"), - ((302, False), (301, False), (307, True), (308, True)), + ("code", "keep"), ((302, False), (301, False), (307, True), (308, True)) ) def test_follow_redirect_body(code, keep): @Request.application @@ -430,42 +442,42 @@ def test_follow_redirect_body(code, keep): c = Client(app, response_wrapper=BaseResponse) response = c.post( - "/", - follow_redirects=True, - data={"foo": "bar"}, - headers={"X-Foo": "bar"}, + "/", follow_redirects=True, data={"foo": "bar"}, headers={"X-Foo": "bar"} ) assert response.status_code == 200 assert response.data == b"current url: http://localhost/some/redirect/" def test_follow_external_redirect(): - env = create_environ('/', base_url='http://localhost') + env = create_environ("/", base_url="http://localhost") c = Client(external_redirect_demo_app) - pytest.raises(RuntimeError, lambda: - c.get(environ_overrides=env, follow_redirects=True)) + pytest.raises( + RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True) + ) def test_follow_external_redirect_on_same_subdomain(): - env = create_environ('/', base_url='http://example.com') + env = create_environ("/", base_url="http://example.com") c = Client(external_subdomain_redirect_demo_app, allow_subdomain_redirects=True) c.get(environ_overrides=env, follow_redirects=True) # check that this does not work for real external domains - env = create_environ('/', base_url='http://localhost') - pytest.raises(RuntimeError, lambda: - c.get(environ_overrides=env, follow_redirects=True)) + env = create_environ("/", base_url="http://localhost") + pytest.raises( + RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True) + ) # check that subdomain redirects fail if no `allow_subdomain_redirects` is applied c = Client(external_subdomain_redirect_demo_app) - pytest.raises(RuntimeError, lambda: - c.get(environ_overrides=env, follow_redirects=True)) + pytest.raises( + RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True) + ) def test_follow_redirect_loop(): c = Client(redirect_loop_app, response_wrapper=BaseResponse) with pytest.raises(ClientRedirectError): - c.get('/', follow_redirects=True) + c.get("/", follow_redirects=True) def test_follow_redirect_non_root_base_url(): @@ -477,7 +489,9 @@ def test_follow_redirect_non_root_base_url(): return Response(request.path) c = Client(app, response_wrapper=Response) - response = c.get("/redirect", base_url="http://localhost/other", follow_redirects=True) + response = c.get( + "/redirect", base_url="http://localhost/other", follow_redirects=True + ) assert response.data == b"/done" @@ -508,89 +522,84 @@ def test_follow_redirect_exhaust_intermediate(): def test_path_info_script_name_unquoting(): def test_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [environ['PATH_INFO'] + '\n' + environ['SCRIPT_NAME']] + start_response("200 OK", [("Content-Type", "text/plain")]) + return [environ["PATH_INFO"] + "\n" + environ["SCRIPT_NAME"]] + c = Client(test_app, response_wrapper=BaseResponse) - resp = c.get('/foo%40bar') - strict_eq(resp.data, b'/foo@bar\n') + resp = c.get("/foo%40bar") + strict_eq(resp.data, b"/foo@bar\n") c = Client(test_app, response_wrapper=BaseResponse) - resp = c.get('/foo%40bar', 'http://localhost/bar%40baz') - strict_eq(resp.data, b'/foo@bar\n/bar@baz') + resp = c.get("/foo%40bar", "http://localhost/bar%40baz") + strict_eq(resp.data, b"/foo@bar\n/bar@baz") def test_multi_value_submit(): c = Client(multi_value_post_app, response_wrapper=BaseResponse) - data = { - 'field': ['val1', 'val2'] - } - resp = c.post('/', data=data) + data = {"field": ["val1", "val2"]} + resp = c.post("/", data=data) strict_eq(resp.status_code, 200) c = Client(multi_value_post_app, response_wrapper=BaseResponse) - data = MultiDict({ - 'field': ['val1', 'val2'] - }) - resp = c.post('/', data=data) + data = MultiDict({"field": ["val1", "val2"]}) + resp = c.post("/", data=data) strict_eq(resp.status_code, 200) def test_iri_support(): - b = EnvironBuilder(u'/föö-bar', base_url=u'http://☃.net/') - strict_eq(b.path, '/f%C3%B6%C3%B6-bar') - strict_eq(b.base_url, 'http://xn--n3h.net/') + b = EnvironBuilder(u"/föö-bar", base_url=u"http://☃.net/") + strict_eq(b.path, "/f%C3%B6%C3%B6-bar") + strict_eq(b.base_url, "http://xn--n3h.net/") -@pytest.mark.parametrize('buffered', (True, False)) -@pytest.mark.parametrize('iterable', (True, False)) +@pytest.mark.parametrize("buffered", (True, False)) +@pytest.mark.parametrize("iterable", (True, False)) def test_run_wsgi_apps(buffered, iterable): leaked_data = [] def simple_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return ['Hello World!'] + start_response("200 OK", [("Content-Type", "text/html")]) + return ["Hello World!"] def yielding_app(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - yield 'Hello ' - yield 'World!' + start_response("200 OK", [("Content-Type", "text/html")]) + yield "Hello " + yield "World!" def late_start_response(environ, start_response): - yield 'Hello ' - yield 'World' - start_response('200 OK', [('Content-Type', 'text/html')]) - yield '!' + yield "Hello " + yield "World" + start_response("200 OK", [("Content-Type", "text/html")]) + yield "!" def depends_on_close(environ, start_response): - leaked_data.append('harhar') - start_response('200 OK', [('Content-Type', 'text/html')]) + leaked_data.append("harhar") + start_response("200 OK", [("Content-Type", "text/html")]) class Rv(object): - def __iter__(self): - yield 'Hello ' - yield 'World' - yield '!' + yield "Hello " + yield "World" + yield "!" def close(self): - assert leaked_data.pop() == 'harhar' + assert leaked_data.pop() == "harhar" return Rv() - for app in (simple_app, yielding_app, late_start_response, - depends_on_close): + for app in (simple_app, yielding_app, late_start_response, depends_on_close): if iterable: app = iterable_middleware(app) app_iter, status, headers = run_wsgi_app(app, {}, buffered=buffered) - strict_eq(status, '200 OK') - strict_eq(list(headers), [('Content-Type', 'text/html')]) - strict_eq(''.join(app_iter), 'Hello World!') + strict_eq(status, "200 OK") + strict_eq(list(headers), [("Content-Type", "text/html")]) + strict_eq("".join(app_iter), "Hello World!") - if hasattr(app_iter, 'close'): + if hasattr(app_iter, "close"): app_iter.close() assert not leaked_data -@pytest.mark.parametrize('buffered', (True, False)) -@pytest.mark.parametrize('iterable', (True, False)) +@pytest.mark.parametrize("buffered", (True, False)) +@pytest.mark.parametrize("iterable", (True, False)) def test_lazy_start_response_empty_response_app(buffered, iterable): @implements_iterator class app: @@ -601,15 +610,15 @@ def test_lazy_start_response_empty_response_app(buffered, iterable): return self def __next__(self): - self.start_response('200 OK', [('Content-Type', 'text/html')]) + self.start_response("200 OK", [("Content-Type", "text/html")]) raise StopIteration if iterable: app = iterable_middleware(app) app_iter, status, headers = run_wsgi_app(app, {}, buffered=buffered) - strict_eq(status, '200 OK') - strict_eq(list(headers), [('Content-Type', 'text/html')]) - strict_eq(''.join(app_iter), '') + strict_eq(status, "200 OK") + strict_eq(list(headers), [("Content-Type", "text/html")]) + strict_eq("".join(app_iter), "") def test_run_wsgi_app_closing_iterator(): @@ -617,7 +626,6 @@ def test_run_wsgi_app_closing_iterator(): @implements_iterator class CloseIter(object): - def __init__(self): self.iterated = False @@ -631,39 +639,41 @@ def test_run_wsgi_app_closing_iterator(): if self.iterated: raise StopIteration() self.iterated = True - return 'bar' + return "bar" def bar(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) return CloseIter() app_iter, status, headers = run_wsgi_app(bar, {}) - assert status == '200 OK' - assert list(headers) == [('Content-Type', 'text/plain')] - assert next(app_iter) == 'bar' + assert status == "200 OK" + assert list(headers) == [("Content-Type", "text/plain")] + assert next(app_iter) == "bar" pytest.raises(StopIteration, partial(next, app_iter)) app_iter.close() - assert run_wsgi_app(bar, {}, True)[0] == ['bar'] + assert run_wsgi_app(bar, {}, True)[0] == ["bar"] assert len(got_close) == 2 def iterable_middleware(app): - '''Guarantee that the app returns an iterable''' + """Guarantee that the app returns an iterable""" + def inner(environ, start_response): rv = app(environ, start_response) class Iterable(object): - def __iter__(self): return iter(rv) - if hasattr(rv, 'close'): + if hasattr(rv, "close"): + def close(self): rv.close() return Iterable() + return inner @@ -671,15 +681,17 @@ def test_multiple_cookies(): @Request.application def test_app(request): response = Response(repr(sorted(request.cookies.items()))) - response.set_cookie(u'test1', b'foo') - response.set_cookie(u'test2', b'bar') + response.set_cookie(u"test1", b"foo") + response.set_cookie(u"test2", b"bar") return response + client = Client(test_app, Response) - resp = client.get('/') - strict_eq(resp.data, b'[]') - resp = client.get('/') - strict_eq(resp.data, - to_bytes(repr([('test1', u'foo'), ('test2', u'bar')]), 'ascii')) + resp = client.get("/") + strict_eq(resp.data, b"[]") + resp = client.get("/") + strict_eq( + resp.data, to_bytes(repr([("test1", u"foo"), ("test2", u"bar")]), "ascii") + ) def test_correct_open_invocation_on_redirect(): @@ -688,58 +700,59 @@ def test_correct_open_invocation_on_redirect(): def open(self, *args, **kwargs): self.counter += 1 - env = kwargs.setdefault('environ_overrides', {}) - env['werkzeug._foo'] = self.counter + env = kwargs.setdefault("environ_overrides", {}) + env["werkzeug._foo"] = self.counter return Client.open(self, *args, **kwargs) @Request.application def test_app(request): - return Response(str(request.environ['werkzeug._foo'])) + return Response(str(request.environ["werkzeug._foo"])) c = MyClient(test_app, response_wrapper=Response) - strict_eq(c.get('/').data, b'1') - strict_eq(c.get('/').data, b'2') - strict_eq(c.get('/').data, b'3') + strict_eq(c.get("/").data, b"1") + strict_eq(c.get("/").data, b"2") + strict_eq(c.get("/").data, b"3") def test_correct_encoding(): - req = Request.from_values(u'/\N{SNOWMAN}', u'http://example.com/foo') - strict_eq(req.script_root, u'/foo') - strict_eq(req.path, u'/\N{SNOWMAN}') + req = Request.from_values(u"/\N{SNOWMAN}", u"http://example.com/foo") + strict_eq(req.script_root, u"/foo") + strict_eq(req.path, u"/\N{SNOWMAN}") def test_full_url_requests_with_args(): - base = 'http://example.com/' + base = "http://example.com/" @Request.application def test_app(request): - return Response(request.args['x']) + return Response(request.args["x"]) + client = Client(test_app, Response) - resp = client.get('/?x=42', base) - strict_eq(resp.data, b'42') - resp = client.get('http://www.example.com/?x=23', base) - strict_eq(resp.data, b'23') + resp = client.get("/?x=42", base) + strict_eq(resp.data, b"42") + resp = client.get("http://www.example.com/?x=23", base) + strict_eq(resp.data, b"23") def test_delete_requests_with_form(): @Request.application def test_app(request): - return Response(request.form.get('x', None)) + return Response(request.form.get("x", None)) client = Client(test_app, Response) - resp = client.delete('/', data={'x': 42}) - strict_eq(resp.data, b'42') + resp = client.delete("/", data={"x": 42}) + strict_eq(resp.data, b"42") def test_post_with_file_descriptor(tmpdir): c = Client(Response(), response_wrapper=Response) - f = tmpdir.join('some-file.txt') - f.write('foo') - with open(f.strpath, mode='rt') as data: - resp = c.post('/', data=data) + f = tmpdir.join("some-file.txt") + f.write("foo") + with open(f.strpath, mode="rt") as data: + resp = c.post("/", data=data) strict_eq(resp.status_code, 200) - with open(f.strpath, mode='rb') as data: - resp = c.post('/', data=data) + with open(f.strpath, mode="rb") as data: + resp = c.post("/", data=data) strict_eq(resp.status_code, 200) @@ -747,13 +760,14 @@ def test_content_type(): @Request.application def test_app(request): return Response(request.content_type) + client = Client(test_app, Response) - resp = client.get('/', data=b'testing', mimetype='text/css') - strict_eq(resp.data, b'text/css; charset=utf-8') + resp = client.get("/", data=b"testing", mimetype="text/css") + strict_eq(resp.data, b"text/css; charset=utf-8") - resp = client.get('/', data=b'testing', mimetype='application/octet-stream') - strict_eq(resp.data, b'application/octet-stream') + resp = client.get("/", data=b"testing", mimetype="application/octet-stream") + strict_eq(resp.data, b"application/octet-stream") def test_raw_request_uri(): diff --git a/tests/test_urls.py b/tests/test_urls.py index ae3db554..71b80c33 100644 --- a/tests/test_urls.py +++ b/tests/test_urls.py @@ -10,394 +10,426 @@ """ import pytest -from tests import strict_eq - -from werkzeug.datastructures import OrderedMultiDict +from . import strict_eq from werkzeug import urls -from werkzeug._compat import text_type, NativeStringIO, BytesIO +from werkzeug._compat import BytesIO +from werkzeug._compat import NativeStringIO +from werkzeug._compat import text_type +from werkzeug.datastructures import OrderedMultiDict def test_parsing(): - url = urls.url_parse('http://anon:hunter2@[2001:db8:0:1]:80/a/b/c') - assert url.netloc == 'anon:hunter2@[2001:db8:0:1]:80' - assert url.username == 'anon' - assert url.password == 'hunter2' + url = urls.url_parse("http://anon:hunter2@[2001:db8:0:1]:80/a/b/c") + assert url.netloc == "anon:hunter2@[2001:db8:0:1]:80" + assert url.username == "anon" + assert url.password == "hunter2" assert url.port == 80 - assert url.ascii_host == '2001:db8:0:1' + assert url.ascii_host == "2001:db8:0:1" assert url.get_file_location() == (None, None) # no file scheme -@pytest.mark.parametrize('implicit_format', (True, False)) -@pytest.mark.parametrize('localhost', ('127.0.0.1', '::1', 'localhost')) +@pytest.mark.parametrize("implicit_format", (True, False)) +@pytest.mark.parametrize("localhost", ("127.0.0.1", "::1", "localhost")) def test_fileurl_parsing_windows(implicit_format, localhost, monkeypatch): if implicit_format: pathformat = None - monkeypatch.setattr('os.name', 'nt') + monkeypatch.setattr("os.name", "nt") else: - pathformat = 'windows' - monkeypatch.delattr('os.name') # just to make sure it won't get used + pathformat = "windows" + monkeypatch.delattr("os.name") # just to make sure it won't get used - url = urls.url_parse('file:///C:/Documents and Settings/Foobar/stuff.txt') - assert url.netloc == '' - assert url.scheme == 'file' - assert url.get_file_location(pathformat) == \ - (None, r'C:\Documents and Settings\Foobar\stuff.txt') + url = urls.url_parse("file:///C:/Documents and Settings/Foobar/stuff.txt") + assert url.netloc == "" + assert url.scheme == "file" + assert url.get_file_location(pathformat) == ( + None, + r"C:\Documents and Settings\Foobar\stuff.txt", + ) - url = urls.url_parse('file://///server.tld/file.txt') - assert url.get_file_location(pathformat) == ('server.tld', r'file.txt') + url = urls.url_parse("file://///server.tld/file.txt") + assert url.get_file_location(pathformat) == ("server.tld", r"file.txt") - url = urls.url_parse('file://///server.tld') - assert url.get_file_location(pathformat) == ('server.tld', '') + url = urls.url_parse("file://///server.tld") + assert url.get_file_location(pathformat) == ("server.tld", "") - url = urls.url_parse('file://///%s' % localhost) - assert url.get_file_location(pathformat) == (None, '') + url = urls.url_parse("file://///%s" % localhost) + assert url.get_file_location(pathformat) == (None, "") - url = urls.url_parse('file://///%s/file.txt' % localhost) - assert url.get_file_location(pathformat) == (None, r'file.txt') + url = urls.url_parse("file://///%s/file.txt" % localhost) + assert url.get_file_location(pathformat) == (None, r"file.txt") def test_replace(): - url = urls.url_parse('http://de.wikipedia.org/wiki/Troll') - strict_eq(url.replace(query='foo=bar'), - urls.url_parse('http://de.wikipedia.org/wiki/Troll?foo=bar')) - strict_eq(url.replace(scheme='https'), - urls.url_parse('https://de.wikipedia.org/wiki/Troll')) + url = urls.url_parse("http://de.wikipedia.org/wiki/Troll") + strict_eq( + url.replace(query="foo=bar"), + urls.url_parse("http://de.wikipedia.org/wiki/Troll?foo=bar"), + ) + strict_eq( + url.replace(scheme="https"), + urls.url_parse("https://de.wikipedia.org/wiki/Troll"), + ) def test_quoting(): - strict_eq(urls.url_quote(u'\xf6\xe4\xfc'), '%C3%B6%C3%A4%C3%BC') + strict_eq(urls.url_quote(u"\xf6\xe4\xfc"), "%C3%B6%C3%A4%C3%BC") strict_eq(urls.url_unquote(urls.url_quote(u'#%="\xf6')), u'#%="\xf6') - strict_eq(urls.url_quote_plus('foo bar'), 'foo+bar') - strict_eq(urls.url_unquote_plus('foo+bar'), u'foo bar') - strict_eq(urls.url_quote_plus('foo+bar'), 'foo%2Bbar') - strict_eq(urls.url_unquote_plus('foo%2Bbar'), u'foo+bar') - strict_eq(urls.url_encode({b'a': None, b'b': b'foo bar'}), 'b=foo+bar') - strict_eq(urls.url_encode({u'a': None, u'b': u'foo bar'}), 'b=foo+bar') - strict_eq(urls.url_fix(u'http://de.wikipedia.org/wiki/Elf (Begriffsklärung)'), - 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)') - strict_eq(urls.url_quote_plus(42), '42') - strict_eq(urls.url_quote(b'\xff'), '%FF') + strict_eq(urls.url_quote_plus("foo bar"), "foo+bar") + strict_eq(urls.url_unquote_plus("foo+bar"), u"foo bar") + strict_eq(urls.url_quote_plus("foo+bar"), "foo%2Bbar") + strict_eq(urls.url_unquote_plus("foo%2Bbar"), u"foo+bar") + strict_eq(urls.url_encode({b"a": None, b"b": b"foo bar"}), "b=foo+bar") + strict_eq(urls.url_encode({u"a": None, u"b": u"foo bar"}), "b=foo+bar") + strict_eq( + urls.url_fix(u"http://de.wikipedia.org/wiki/Elf (Begriffsklärung)"), + "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)", + ) + strict_eq(urls.url_quote_plus(42), "42") + strict_eq(urls.url_quote(b"\xff"), "%FF") def test_bytes_unquoting(): - strict_eq(urls.url_unquote(urls.url_quote( - u'#%="\xf6', charset='latin1'), charset=None), b'#%="\xf6') + strict_eq( + urls.url_unquote(urls.url_quote(u'#%="\xf6', charset="latin1"), charset=None), + b'#%="\xf6', + ) def test_url_decoding(): - x = urls.url_decode(b'foo=42&bar=23&uni=H%C3%A4nsel') - strict_eq(x['foo'], u'42') - strict_eq(x['bar'], u'23') - strict_eq(x['uni'], u'Hänsel') + x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel") + strict_eq(x["foo"], u"42") + strict_eq(x["bar"], u"23") + strict_eq(x["uni"], u"Hänsel") - x = urls.url_decode(b'foo=42;bar=23;uni=H%C3%A4nsel', separator=b';') - strict_eq(x['foo'], u'42') - strict_eq(x['bar'], u'23') - strict_eq(x['uni'], u'Hänsel') + x = urls.url_decode(b"foo=42;bar=23;uni=H%C3%A4nsel", separator=b";") + strict_eq(x["foo"], u"42") + strict_eq(x["bar"], u"23") + strict_eq(x["uni"], u"Hänsel") - x = urls.url_decode(b'%C3%9Ch=H%C3%A4nsel', decode_keys=True) - strict_eq(x[u'Üh'], u'Hänsel') + x = urls.url_decode(b"%C3%9Ch=H%C3%A4nsel", decode_keys=True) + strict_eq(x[u"Üh"], u"Hänsel") def test_url_bytes_decoding(): - x = urls.url_decode(b'foo=42&bar=23&uni=H%C3%A4nsel', charset=None) - strict_eq(x[b'foo'], b'42') - strict_eq(x[b'bar'], b'23') - strict_eq(x[b'uni'], u'Hänsel'.encode('utf-8')) + x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel", charset=None) + strict_eq(x[b"foo"], b"42") + strict_eq(x[b"bar"], b"23") + strict_eq(x[b"uni"], u"Hänsel".encode("utf-8")) def test_streamed_url_decoding(): - item1 = u'a' * 100000 - item2 = u'b' * 400 - string = ('a=%s&b=%s&c=%s' % (item1, item2, item2)).encode('ascii') - gen = urls.url_decode_stream(BytesIO(string), limit=len(string), - return_iterator=True) - strict_eq(next(gen), ('a', item1)) - strict_eq(next(gen), ('b', item2)) - strict_eq(next(gen), ('c', item2)) + item1 = u"a" * 100000 + item2 = u"b" * 400 + string = ("a=%s&b=%s&c=%s" % (item1, item2, item2)).encode("ascii") + gen = urls.url_decode_stream( + BytesIO(string), limit=len(string), return_iterator=True + ) + strict_eq(next(gen), ("a", item1)) + strict_eq(next(gen), ("b", item2)) + strict_eq(next(gen), ("c", item2)) pytest.raises(StopIteration, lambda: next(gen)) def test_stream_decoding_string_fails(): - pytest.raises(TypeError, urls.url_decode_stream, 'testing') + pytest.raises(TypeError, urls.url_decode_stream, "testing") def test_url_encoding(): - strict_eq(urls.url_encode({'foo': 'bar 45'}), 'foo=bar+45') - d = {'foo': 1, 'bar': 23, 'blah': u'Hänsel'} - strict_eq(urls.url_encode(d, sort=True), 'bar=23&blah=H%C3%A4nsel&foo=1') - strict_eq(urls.url_encode(d, sort=True, separator=u';'), 'bar=23;blah=H%C3%A4nsel;foo=1') + strict_eq(urls.url_encode({"foo": "bar 45"}), "foo=bar+45") + d = {"foo": 1, "bar": 23, "blah": u"Hänsel"} + strict_eq(urls.url_encode(d, sort=True), "bar=23&blah=H%C3%A4nsel&foo=1") + strict_eq( + urls.url_encode(d, sort=True, separator=u";"), "bar=23;blah=H%C3%A4nsel;foo=1" + ) def test_sorted_url_encode(): - strict_eq(urls.url_encode({u"a": 42, u"b": 23, 1: 1, 2: 2}, - sort=True, key=lambda i: text_type(i[0])), '1=1&2=2&a=42&b=23') - strict_eq(urls.url_encode({u'A': 1, u'a': 2, u'B': 3, 'b': 4}, sort=True, - key=lambda x: x[0].lower() + x[0]), 'A=1&a=2&B=3&b=4') + strict_eq( + urls.url_encode( + {u"a": 42, u"b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: text_type(i[0]) + ), + "1=1&2=2&a=42&b=23", + ) + strict_eq( + urls.url_encode( + {u"A": 1, u"a": 2, u"B": 3, "b": 4}, + sort=True, + key=lambda x: x[0].lower() + x[0], + ), + "A=1&a=2&B=3&b=4", + ) def test_streamed_url_encoding(): out = NativeStringIO() - urls.url_encode_stream({'foo': 'bar 45'}, out) - strict_eq(out.getvalue(), 'foo=bar+45') + urls.url_encode_stream({"foo": "bar 45"}, out) + strict_eq(out.getvalue(), "foo=bar+45") - d = {'foo': 1, 'bar': 23, 'blah': u'Hänsel'} + d = {"foo": 1, "bar": 23, "blah": u"Hänsel"} out = NativeStringIO() urls.url_encode_stream(d, out, sort=True) - strict_eq(out.getvalue(), 'bar=23&blah=H%C3%A4nsel&foo=1') + strict_eq(out.getvalue(), "bar=23&blah=H%C3%A4nsel&foo=1") out = NativeStringIO() - urls.url_encode_stream(d, out, sort=True, separator=u';') - strict_eq(out.getvalue(), 'bar=23;blah=H%C3%A4nsel;foo=1') + urls.url_encode_stream(d, out, sort=True, separator=u";") + strict_eq(out.getvalue(), "bar=23;blah=H%C3%A4nsel;foo=1") gen = urls.url_encode_stream(d, sort=True) - strict_eq(next(gen), 'bar=23') - strict_eq(next(gen), 'blah=H%C3%A4nsel') - strict_eq(next(gen), 'foo=1') + strict_eq(next(gen), "bar=23") + strict_eq(next(gen), "blah=H%C3%A4nsel") + strict_eq(next(gen), "foo=1") pytest.raises(StopIteration, lambda: next(gen)) def test_url_fixing(): - x = urls.url_fix(u'http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)') - assert x == 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)' + x = urls.url_fix(u"http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)") + assert x == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" x = urls.url_fix("http://just.a.test/$-_.+!*'(),") assert x == "http://just.a.test/$-_.+!*'()," - x = urls.url_fix('http://höhöhö.at/höhöhö/hähähä') - assert x == r'http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4' + x = urls.url_fix("http://höhöhö.at/höhöhö/hähähä") + assert x == r"http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4" def test_url_fixing_filepaths(): - x = urls.url_fix(r'file://C:\Users\Administrator\My Documents\ÑÈáÇíí') - assert x == (r'file:///C%3A/Users/Administrator/My%20Documents/' - r'%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD') + x = urls.url_fix(r"file://C:\Users\Administrator\My Documents\ÑÈáÇíí") + assert x == ( + r"file:///C%3A/Users/Administrator/My%20Documents/" + r"%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD" + ) - a = urls.url_fix(r'file:/C:/') - b = urls.url_fix(r'file://C:/') - c = urls.url_fix(r'file:///C:/') - assert a == b == c == r'file:///C%3A/' + a = urls.url_fix(r"file:/C:/") + b = urls.url_fix(r"file://C:/") + c = urls.url_fix(r"file:///C:/") + assert a == b == c == r"file:///C%3A/" - x = urls.url_fix(r'file://host/sub/path') - assert x == r'file://host/sub/path' + x = urls.url_fix(r"file://host/sub/path") + assert x == r"file://host/sub/path" - x = urls.url_fix(r'file:///') - assert x == r'file:///' + x = urls.url_fix(r"file:///") + assert x == r"file:///" def test_url_fixing_qs(): - x = urls.url_fix(b'http://example.com/?foo=%2f%2f') - assert x == 'http://example.com/?foo=%2f%2f' + x = urls.url_fix(b"http://example.com/?foo=%2f%2f") + assert x == "http://example.com/?foo=%2f%2f" - x = urls.url_fix('http://acronyms.thefreedictionary.com/' - 'Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation') - assert x == ('http://acronyms.thefreedictionary.com/' - 'Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation') + x = urls.url_fix( + "http://acronyms.thefreedictionary.com/" + "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" + ) + assert x == ( + "http://acronyms.thefreedictionary.com/" + "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" + ) def test_iri_support(): - strict_eq(urls.uri_to_iri('http://xn--n3h.net/'), - u'http://\u2603.net/') + strict_eq(urls.uri_to_iri("http://xn--n3h.net/"), u"http://\u2603.net/") strict_eq( - urls.uri_to_iri(b'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th'), - u'http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th') - strict_eq(urls.iri_to_uri(u'http://☃.net/'), 'http://xn--n3h.net/') + urls.uri_to_iri(b"http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th"), + u"http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th", + ) + strict_eq(urls.iri_to_uri(u"http://☃.net/"), "http://xn--n3h.net/") strict_eq( - urls.iri_to_uri(u'http://üser:pässword@☃.net/påth'), - 'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th') + urls.iri_to_uri(u"http://üser:pässword@☃.net/påth"), + "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th", + ) - strict_eq(urls.uri_to_iri('http://test.com/%3Fmeh?foo=%26%2F'), - u'http://test.com/%3Fmeh?foo=%26%2F') + strict_eq( + urls.uri_to_iri("http://test.com/%3Fmeh?foo=%26%2F"), + u"http://test.com/%3Fmeh?foo=%26%2F", + ) # this should work as well, might break on 2.4 because of a broken # idna codec - strict_eq(urls.uri_to_iri(b'/foo'), u'/foo') - strict_eq(urls.iri_to_uri(u'/foo'), '/foo') + strict_eq(urls.uri_to_iri(b"/foo"), u"/foo") + strict_eq(urls.iri_to_uri(u"/foo"), "/foo") - strict_eq(urls.iri_to_uri(u'http://föö.com:8080/bam/baz'), - 'http://xn--f-1gaa.com:8080/bam/baz') + strict_eq( + urls.iri_to_uri(u"http://föö.com:8080/bam/baz"), + "http://xn--f-1gaa.com:8080/bam/baz", + ) def test_iri_safe_conversion(): - strict_eq(urls.iri_to_uri(u'magnet:?foo=bar'), - 'magnet:?foo=bar') - strict_eq(urls.iri_to_uri(u'itms-service://?foo=bar'), - 'itms-service:?foo=bar') - strict_eq(urls.iri_to_uri(u'itms-service://?foo=bar', - safe_conversion=True), - 'itms-service://?foo=bar') + strict_eq(urls.iri_to_uri(u"magnet:?foo=bar"), "magnet:?foo=bar") + strict_eq(urls.iri_to_uri(u"itms-service://?foo=bar"), "itms-service:?foo=bar") + strict_eq( + urls.iri_to_uri(u"itms-service://?foo=bar", safe_conversion=True), + "itms-service://?foo=bar", + ) def test_iri_safe_quoting(): - uri = 'http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25' - iri = u'http://föö.com/%2F%25?q=ö&x=%3D%25#%25' + uri = "http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25" + iri = u"http://föö.com/%2F%25?q=ö&x=%3D%25#%25" strict_eq(urls.uri_to_iri(uri), iri) strict_eq(urls.iri_to_uri(urls.uri_to_iri(uri)), uri) def test_ordered_multidict_encoding(): d = OrderedMultiDict() - d.add('foo', 1) - d.add('foo', 2) - d.add('foo', 3) - d.add('bar', 0) - d.add('foo', 4) - assert urls.url_encode(d) == 'foo=1&foo=2&foo=3&bar=0&foo=4' + d.add("foo", 1) + d.add("foo", 2) + d.add("foo", 3) + d.add("bar", 0) + d.add("foo", 4) + assert urls.url_encode(d) == "foo=1&foo=2&foo=3&bar=0&foo=4" def test_multidict_encoding(): d = OrderedMultiDict() - d.add('2013-10-10T23:26:05.657975+0000', '2013-10-10T23:26:05.657975+0000') - assert urls.url_encode( - d) == '2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000' + d.add("2013-10-10T23:26:05.657975+0000", "2013-10-10T23:26:05.657975+0000") + assert ( + urls.url_encode(d) + == "2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000" + ) def test_href(): - x = urls.Href('http://www.example.com/') - strict_eq(x(u'foo'), 'http://www.example.com/foo') - strict_eq(x.foo(u'bar'), 'http://www.example.com/foo/bar') - strict_eq(x.foo(u'bar', x=42), 'http://www.example.com/foo/bar?x=42') - strict_eq(x.foo(u'bar', class_=42), 'http://www.example.com/foo/bar?class=42') - strict_eq(x.foo(u'bar', {u'class': 42}), 'http://www.example.com/foo/bar?class=42') + x = urls.Href("http://www.example.com/") + strict_eq(x(u"foo"), "http://www.example.com/foo") + strict_eq(x.foo(u"bar"), "http://www.example.com/foo/bar") + strict_eq(x.foo(u"bar", x=42), "http://www.example.com/foo/bar?x=42") + strict_eq(x.foo(u"bar", class_=42), "http://www.example.com/foo/bar?class=42") + strict_eq(x.foo(u"bar", {u"class": 42}), "http://www.example.com/foo/bar?class=42") pytest.raises(AttributeError, lambda: x.__blah__) - x = urls.Href('blah') - strict_eq(x.foo(u'bar'), 'blah/foo/bar') + x = urls.Href("blah") + strict_eq(x.foo(u"bar"), "blah/foo/bar") pytest.raises(TypeError, x.foo, {u"foo": 23}, x=42) - x = urls.Href('') - strict_eq(x('foo'), 'foo') + x = urls.Href("") + strict_eq(x("foo"), "foo") def test_href_url_join(): - x = urls.Href(u'test') - assert x(u'foo:bar') == u'test/foo:bar' - assert x(u'http://example.com/') == u'test/http://example.com/' - assert x.a() == u'test/a' + x = urls.Href(u"test") + assert x(u"foo:bar") == u"test/foo:bar" + assert x(u"http://example.com/") == u"test/http://example.com/" + assert x.a() == u"test/a" def test_href_past_root(): - base_href = urls.Href('http://www.blagga.com/1/2/3') - strict_eq(base_href('../foo'), 'http://www.blagga.com/1/2/foo') - strict_eq(base_href('../../foo'), 'http://www.blagga.com/1/foo') - strict_eq(base_href('../../../foo'), 'http://www.blagga.com/foo') - strict_eq(base_href('../../../../foo'), 'http://www.blagga.com/foo') - strict_eq(base_href('../../../../../foo'), 'http://www.blagga.com/foo') - strict_eq(base_href('../../../../../../foo'), 'http://www.blagga.com/foo') + base_href = urls.Href("http://www.blagga.com/1/2/3") + strict_eq(base_href("../foo"), "http://www.blagga.com/1/2/foo") + strict_eq(base_href("../../foo"), "http://www.blagga.com/1/foo") + strict_eq(base_href("../../../foo"), "http://www.blagga.com/foo") + strict_eq(base_href("../../../../foo"), "http://www.blagga.com/foo") + strict_eq(base_href("../../../../../foo"), "http://www.blagga.com/foo") + strict_eq(base_href("../../../../../../foo"), "http://www.blagga.com/foo") def test_url_unquote_plus_unicode(): # was broken in 0.6 - strict_eq(urls.url_unquote_plus(u'\x6d'), u'\x6d') - assert type(urls.url_unquote_plus(u'\x6d')) is text_type + strict_eq(urls.url_unquote_plus(u"\x6d"), u"\x6d") + assert type(urls.url_unquote_plus(u"\x6d")) is text_type def test_quoting_of_local_urls(): - rv = urls.iri_to_uri(u'/foo\x8f') - strict_eq(rv, '/foo%C2%8F') + rv = urls.iri_to_uri(u"/foo\x8f") + strict_eq(rv, "/foo%C2%8F") assert type(rv) is str def test_url_attributes(): - rv = urls.url_parse('http://foo%3a:bar%3a@[::1]:80/123?x=y#frag') - strict_eq(rv.scheme, 'http') - strict_eq(rv.auth, 'foo%3a:bar%3a') - strict_eq(rv.username, u'foo:') - strict_eq(rv.password, u'bar:') - strict_eq(rv.raw_username, 'foo%3a') - strict_eq(rv.raw_password, 'bar%3a') - strict_eq(rv.host, '::1') + rv = urls.url_parse("http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") + strict_eq(rv.scheme, "http") + strict_eq(rv.auth, "foo%3a:bar%3a") + strict_eq(rv.username, u"foo:") + strict_eq(rv.password, u"bar:") + strict_eq(rv.raw_username, "foo%3a") + strict_eq(rv.raw_password, "bar%3a") + strict_eq(rv.host, "::1") assert rv.port == 80 - strict_eq(rv.path, '/123') - strict_eq(rv.query, 'x=y') - strict_eq(rv.fragment, 'frag') + strict_eq(rv.path, "/123") + strict_eq(rv.query, "x=y") + strict_eq(rv.fragment, "frag") - rv = urls.url_parse(u'http://\N{SNOWMAN}.com/') - strict_eq(rv.host, u'\N{SNOWMAN}.com') - strict_eq(rv.ascii_host, 'xn--n3h.com') + rv = urls.url_parse(u"http://\N{SNOWMAN}.com/") + strict_eq(rv.host, u"\N{SNOWMAN}.com") + strict_eq(rv.ascii_host, "xn--n3h.com") def test_url_attributes_bytes(): - rv = urls.url_parse(b'http://foo%3a:bar%3a@[::1]:80/123?x=y#frag') - strict_eq(rv.scheme, b'http') - strict_eq(rv.auth, b'foo%3a:bar%3a') - strict_eq(rv.username, u'foo:') - strict_eq(rv.password, u'bar:') - strict_eq(rv.raw_username, b'foo%3a') - strict_eq(rv.raw_password, b'bar%3a') - strict_eq(rv.host, b'::1') + rv = urls.url_parse(b"http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") + strict_eq(rv.scheme, b"http") + strict_eq(rv.auth, b"foo%3a:bar%3a") + strict_eq(rv.username, u"foo:") + strict_eq(rv.password, u"bar:") + strict_eq(rv.raw_username, b"foo%3a") + strict_eq(rv.raw_password, b"bar%3a") + strict_eq(rv.host, b"::1") assert rv.port == 80 - strict_eq(rv.path, b'/123') - strict_eq(rv.query, b'x=y') - strict_eq(rv.fragment, b'frag') + strict_eq(rv.path, b"/123") + strict_eq(rv.query, b"x=y") + strict_eq(rv.fragment, b"frag") def test_url_joining(): - strict_eq(urls.url_join('/foo', '/bar'), '/bar') - strict_eq(urls.url_join('http://example.com/foo', '/bar'), - 'http://example.com/bar') - strict_eq(urls.url_join('file:///tmp/', 'test.html'), - 'file:///tmp/test.html') - strict_eq(urls.url_join('file:///tmp/x', 'test.html'), - 'file:///tmp/test.html') - strict_eq(urls.url_join('file:///tmp/x', '../../../x.html'), - 'file:///x.html') + strict_eq(urls.url_join("/foo", "/bar"), "/bar") + strict_eq(urls.url_join("http://example.com/foo", "/bar"), "http://example.com/bar") + strict_eq(urls.url_join("file:///tmp/", "test.html"), "file:///tmp/test.html") + strict_eq(urls.url_join("file:///tmp/x", "test.html"), "file:///tmp/test.html") + strict_eq(urls.url_join("file:///tmp/x", "../../../x.html"), "file:///x.html") def test_partial_unencoded_decode(): - ref = u'foo=정상처리'.encode('euc-kr') - x = urls.url_decode(ref, charset='euc-kr') - strict_eq(x['foo'], u'정상처리') + ref = u"foo=정상처리".encode("euc-kr") + x = urls.url_decode(ref, charset="euc-kr") + strict_eq(x["foo"], u"정상처리") def test_iri_to_uri_idempotence_ascii_only(): - uri = u'http://www.idempoten.ce' + uri = u"http://www.idempoten.ce" uri = urls.iri_to_uri(uri) assert urls.iri_to_uri(uri) == uri def test_iri_to_uri_idempotence_non_ascii(): - uri = u'http://\N{SNOWMAN}/\N{SNOWMAN}' + uri = u"http://\N{SNOWMAN}/\N{SNOWMAN}" uri = urls.iri_to_uri(uri) assert urls.iri_to_uri(uri) == uri def test_uri_to_iri_idempotence_ascii_only(): - uri = 'http://www.idempoten.ce' + uri = "http://www.idempoten.ce" uri = urls.uri_to_iri(uri) assert urls.uri_to_iri(uri) == uri def test_uri_to_iri_idempotence_non_ascii(): - uri = 'http://xn--n3h/%E2%98%83' + uri = "http://xn--n3h/%E2%98%83" uri = urls.uri_to_iri(uri) assert urls.uri_to_iri(uri) == uri def test_iri_to_uri_to_iri(): - iri = u'http://föö.com/' + iri = u"http://föö.com/" uri = urls.iri_to_uri(iri) assert urls.uri_to_iri(uri) == iri def test_uri_to_iri_to_uri(): - uri = 'http://xn--f-rgao.com/%C3%9E' + uri = "http://xn--f-rgao.com/%C3%9E" iri = urls.uri_to_iri(uri) assert urls.iri_to_uri(iri) == uri def test_uri_iri_normalization(): - uri = 'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93' - iri = u'http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713' + uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93" + iri = u"http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713" tests = [ - u'http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713', - u'http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}', - b'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93', - u'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93', - u'http://föñ.com/\u2610/fred?utf8=%E2%9C%93', - b'http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93', + u"http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713", + u"http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}", + b"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", + u"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", + u"http://föñ.com/\u2610/fred?utf8=%E2%9C%93", + b"http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93", ] for test in tests: diff --git a/tests/test_utils.py b/tests/test_utils.py index 0435a251..f288edea 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,43 +8,46 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import pytest - -from datetime import datetime import inspect +from datetime import datetime + +import pytest from werkzeug import utils +from werkzeug._compat import text_type from werkzeug.datastructures import Headers -from werkzeug.http import parse_date, http_date -from werkzeug.wrappers import BaseResponse +from werkzeug.http import http_date +from werkzeug.http import parse_date from werkzeug.test import Client -from werkzeug._compat import text_type +from werkzeug.wrappers import BaseResponse def test_redirect(): - resp = utils.redirect(u'/füübär') - assert b'/f%C3%BC%C3%BCb%C3%A4r' in resp.get_data() - assert resp.headers['Location'] == '/f%C3%BC%C3%BCb%C3%A4r' + resp = utils.redirect(u"/füübär") + assert b"/f%C3%BC%C3%BCb%C3%A4r" in resp.get_data() + assert resp.headers["Location"] == "/f%C3%BC%C3%BCb%C3%A4r" assert resp.status_code == 302 - resp = utils.redirect(u'http://☃.net/', 307) - assert b'http://xn--n3h.net/' in resp.get_data() - assert resp.headers['Location'] == 'http://xn--n3h.net/' + resp = utils.redirect(u"http://☃.net/", 307) + assert b"http://xn--n3h.net/" in resp.get_data() + assert resp.headers["Location"] == "http://xn--n3h.net/" assert resp.status_code == 307 - resp = utils.redirect('http://example.com/', 305) - assert resp.headers['Location'] == 'http://example.com/' + resp = utils.redirect("http://example.com/", 305) + assert resp.headers["Location"] == "http://example.com/" assert resp.status_code == 305 def test_redirect_xss(): location = 'http://example.com/?xss="><script>alert(1)</script>' resp = utils.redirect(location) - assert b'<script>alert(1)</script>' not in resp.get_data() + assert b"<script>alert(1)</script>" not in resp.get_data() location = 'http://example.com/?xss="onmouseover="alert(1)' resp = utils.redirect(location) - assert b'href="http://example.com/?xss="onmouseover="alert(1)"' not in resp.get_data() + assert ( + b'href="http://example.com/?xss="onmouseover="alert(1)"' not in resp.get_data() + ) def test_redirect_with_custom_response_class(): @@ -55,17 +58,17 @@ def test_redirect_with_custom_response_class(): resp = utils.redirect(location, Response=MyResponse) assert isinstance(resp, MyResponse) - assert resp.headers['Location'] == location + assert resp.headers["Location"] == location def test_cached_property(): foo = [] class A(object): - def prop(self): foo.append(42) return 42 + prop = utils.cached_property(prop) a = A() @@ -77,11 +80,11 @@ def test_cached_property(): foo = [] class A(object): - def _prop(self): foo.append(42) return 42 - prop = utils.cached_property(_prop, name='prop') + + prop = utils.cached_property(_prop, name="prop") del _prop a = A() @@ -93,198 +96,222 @@ def test_cached_property(): def test_can_set_cached_property(): class A(object): - @utils.cached_property def _prop(self): - return 'cached_property return value' + return "cached_property return value" a = A() - a._prop = 'value' - assert a._prop == 'value' + a._prop = "value" + assert a._prop == "value" def test_inspect_treats_cached_property_as_property(): class A(object): - @utils.cached_property def _prop(self): - return 'cached_property return value' + return "cached_property return value" attrs = inspect.classify_class_attrs(A) for attr in attrs: - if attr.name == '_prop': + if attr.name == "_prop": break - assert attr.kind == 'property' + assert attr.kind == "property" def test_environ_property(): class A(object): - environ = {'string': 'abc', 'number': '42'} - - string = utils.environ_property('string') - missing = utils.environ_property('missing', 'spam') - read_only = utils.environ_property('number') - number = utils.environ_property('number', load_func=int) - broken_number = utils.environ_property('broken_number', load_func=int) - date = utils.environ_property('date', None, parse_date, http_date, - read_only=False) - foo = utils.environ_property('foo') + environ = {"string": "abc", "number": "42"} + + string = utils.environ_property("string") + missing = utils.environ_property("missing", "spam") + read_only = utils.environ_property("number") + number = utils.environ_property("number", load_func=int) + broken_number = utils.environ_property("broken_number", load_func=int) + date = utils.environ_property( + "date", None, parse_date, http_date, read_only=False + ) + foo = utils.environ_property("foo") a = A() - assert a.string == 'abc' - assert a.missing == 'spam' + assert a.string == "abc" + assert a.missing == "spam" def test_assign(): - a.read_only = 'something' + a.read_only = "something" + pytest.raises(AttributeError, test_assign) assert a.number == 42 assert a.broken_number is None assert a.date is None a.date = datetime(2008, 1, 22, 10, 0, 0, 0) - assert a.environ['date'] == 'Tue, 22 Jan 2008 10:00:00 GMT' + assert a.environ["date"] == "Tue, 22 Jan 2008 10:00:00 GMT" def test_escape(): class Foo(str): - def __html__(self): return text_type(self) - assert utils.escape(None) == '' - assert utils.escape(42) == '42' - assert utils.escape('<>') == '<>' - assert utils.escape('"foo"') == '"foo"' - assert utils.escape(Foo('<foo>')) == '<foo>' + + assert utils.escape(None) == "" + assert utils.escape(42) == "42" + assert utils.escape("<>") == "<>" + assert utils.escape('"foo"') == ""foo"" + assert utils.escape(Foo("<foo>")) == "<foo>" def test_unescape(): - assert utils.unescape('<ä>') == u'<ä>' + assert utils.unescape("<ä>") == u"<ä>" def test_import_string(): from datetime import date from werkzeug.debug import DebuggedApplication - assert utils.import_string('datetime.date') is date - assert utils.import_string(u'datetime.date') is date - assert utils.import_string('datetime:date') is date - assert utils.import_string('XXXXXXXXXXXX', True) is None - assert utils.import_string('datetime.XXXXXXXXXXXX', True) is None - assert utils.import_string(u'werkzeug.debug.DebuggedApplication') is DebuggedApplication - pytest.raises(ImportError, utils.import_string, 'XXXXXXXXXXXXXXXX') - pytest.raises(ImportError, utils.import_string, 'datetime.XXXXXXXXXX') + + assert utils.import_string("datetime.date") is date + assert utils.import_string(u"datetime.date") is date + assert utils.import_string("datetime:date") is date + assert utils.import_string("XXXXXXXXXXXX", True) is None + assert utils.import_string("datetime.XXXXXXXXXXXX", True) is None + assert ( + utils.import_string(u"werkzeug.debug.DebuggedApplication") + is DebuggedApplication + ) + pytest.raises(ImportError, utils.import_string, "XXXXXXXXXXXXXXXX") + pytest.raises(ImportError, utils.import_string, "datetime.XXXXXXXXXX") def test_import_string_provides_traceback(tmpdir, monkeypatch): monkeypatch.syspath_prepend(str(tmpdir)) # Couple of packages - dir_a = tmpdir.mkdir('a') - dir_b = tmpdir.mkdir('b') + dir_a = tmpdir.mkdir("a") + dir_b = tmpdir.mkdir("b") # Totally packages, I promise - dir_a.join('__init__.py').write('') - dir_b.join('__init__.py').write('') + dir_a.join("__init__.py").write("") + dir_b.join("__init__.py").write("") # 'aa.a' that depends on 'bb.b', which in turn has a broken import - dir_a.join('aa.py').write('from b import bb') - dir_b.join('bb.py').write('from os import a_typo') + dir_a.join("aa.py").write("from b import bb") + dir_b.join("bb.py").write("from os import a_typo") # Do we get all the useful information in the traceback? with pytest.raises(ImportError) as baz_exc: - utils.import_string('a.aa') - traceback = ''.join((str(line) for line in baz_exc.traceback)) - assert 'bb.py\':1' in traceback # a bit different than typical python tb - assert 'from os import a_typo' in traceback + utils.import_string("a.aa") + traceback = "".join((str(line) for line in baz_exc.traceback)) + assert "bb.py':1" in traceback # a bit different than typical python tb + assert "from os import a_typo" in traceback def test_import_string_attribute_error(tmpdir, monkeypatch): monkeypatch.syspath_prepend(str(tmpdir)) - tmpdir.join('foo_test.py').write('from bar_test import value') - tmpdir.join('bar_test.py').write('raise AttributeError("screw you!")') + tmpdir.join("foo_test.py").write("from bar_test import value") + tmpdir.join("bar_test.py").write('raise AttributeError("screw you!")') with pytest.raises(AttributeError) as foo_exc: - utils.import_string('foo_test') - assert 'screw you!' in str(foo_exc) + utils.import_string("foo_test") + assert "screw you!" in str(foo_exc) with pytest.raises(AttributeError) as bar_exc: - utils.import_string('bar_test') - assert 'screw you!' in str(bar_exc) + utils.import_string("bar_test") + assert "screw you!" in str(bar_exc) def test_find_modules(): - assert list(utils.find_modules('werkzeug.debug')) == [ - 'werkzeug.debug.console', 'werkzeug.debug.repr', - 'werkzeug.debug.tbtools' + assert list(utils.find_modules("werkzeug.debug")) == [ + "werkzeug.debug.console", + "werkzeug.debug.repr", + "werkzeug.debug.tbtools", ] def test_html_builder(): html = utils.html xhtml = utils.xhtml - assert html.p('Hello World') == '<p>Hello World</p>' - assert html.a('Test', href='#') == '<a href="#">Test</a>' - assert html.br() == '<br>' - assert xhtml.br() == '<br />' - assert html.img(src='foo') == '<img src="foo">' - assert xhtml.img(src='foo') == '<img src="foo" />' - assert html.html(html.head( - html.title('foo'), - html.script(type='text/javascript') - )) == ( + assert html.p("Hello World") == "<p>Hello World</p>" + assert html.a("Test", href="#") == '<a href="#">Test</a>' + assert html.br() == "<br>" + assert xhtml.br() == "<br />" + assert html.img(src="foo") == '<img src="foo">' + assert xhtml.img(src="foo") == '<img src="foo" />' + assert html.html( + html.head(html.title("foo"), html.script(type="text/javascript")) + ) == ( '<html><head><title>foo</title><script type="text/javascript">' - '</script></head></html>' + "</script></head></html>" ) - assert html('<foo>') == '<foo>' - assert html.input(disabled=True) == '<input disabled>' + assert html("<foo>") == "<foo>" + assert html.input(disabled=True) == "<input disabled>" assert xhtml.input(disabled=True) == '<input disabled="disabled" />' - assert html.input(disabled='') == '<input>' - assert xhtml.input(disabled='') == '<input />' - assert html.input(disabled=None) == '<input>' - assert xhtml.input(disabled=None) == '<input />' - assert html.script('alert("Hello World");') == \ - '<script>alert("Hello World");</script>' - assert xhtml.script('alert("Hello World");') == \ - '<script>/*<![CDATA[*/alert("Hello World");/*]]>*/</script>' + assert html.input(disabled="") == "<input>" + assert xhtml.input(disabled="") == "<input />" + assert html.input(disabled=None) == "<input>" + assert xhtml.input(disabled=None) == "<input />" + assert ( + html.script('alert("Hello World");') == '<script>alert("Hello World");</script>' + ) + assert ( + xhtml.script('alert("Hello World");') + == '<script>/*<![CDATA[*/alert("Hello World");/*]]>*/</script>' + ) def test_validate_arguments(): - take_none = lambda: None - take_two = lambda a, b: None - take_two_one_default = lambda a, b=0: None + def take_none(): + pass + + def take_two(a, b): + pass - assert utils.validate_arguments(take_two, (1, 2,), {}) == ((1, 2), {}) - assert utils.validate_arguments(take_two, (1,), {'b': 2}) == ((1, 2), {}) + def take_two_one_default(a, b=0): + pass + + assert utils.validate_arguments(take_two, (1, 2), {}) == ((1, 2), {}) + assert utils.validate_arguments(take_two, (1,), {"b": 2}) == ((1, 2), {}) assert utils.validate_arguments(take_two_one_default, (1,), {}) == ((1, 0), {}) assert utils.validate_arguments(take_two_one_default, (1, 2), {}) == ((1, 2), {}) - pytest.raises(utils.ArgumentValidationError, - utils.validate_arguments, take_two, (), {}) + pytest.raises( + utils.ArgumentValidationError, utils.validate_arguments, take_two, (), {} + ) - assert utils.validate_arguments(take_none, (1, 2,), {'c': 3}) == ((), {}) - pytest.raises(utils.ArgumentValidationError, - utils.validate_arguments, take_none, (1,), {}, drop_extra=False) - pytest.raises(utils.ArgumentValidationError, - utils.validate_arguments, take_none, (), {'a': 1}, drop_extra=False) + assert utils.validate_arguments(take_none, (1, 2), {"c": 3}) == ((), {}) + pytest.raises( + utils.ArgumentValidationError, + utils.validate_arguments, + take_none, + (1,), + {}, + drop_extra=False, + ) + pytest.raises( + utils.ArgumentValidationError, + utils.validate_arguments, + take_none, + (), + {"a": 1}, + drop_extra=False, + ) def test_header_set_duplication_bug(): - headers = Headers([ - ('Content-Type', 'text/html'), - ('Foo', 'bar'), - ('Blub', 'blah') - ]) - headers['blub'] = 'hehe' - headers['blafasel'] = 'humm' - assert headers == Headers([ - ('Content-Type', 'text/html'), - ('Foo', 'bar'), - ('blub', 'hehe'), - ('blafasel', 'humm') - ]) + headers = Headers([("Content-Type", "text/html"), ("Foo", "bar"), ("Blub", "blah")]) + headers["blub"] = "hehe" + headers["blafasel"] = "humm" + assert headers == Headers( + [ + ("Content-Type", "text/html"), + ("Foo", "bar"), + ("blub", "hehe"), + ("blafasel", "humm"), + ] + ) def test_append_slash_redirect(): def app(env, sr): return utils.append_slash_redirect(env)(env, sr) + client = Client(app, BaseResponse) - response = client.get('foo', base_url='http://example.org/app') + response = client.get("foo", base_url="http://example.org/app") assert response.status_code == 301 - assert response.headers['Location'] == 'http://example.org/app/foo/' + assert response.headers["Location"] == "http://example.org/app/foo/" def test_cached_property_doc(): @@ -292,15 +319,18 @@ def test_cached_property_doc(): def foo(): """testing""" return 42 - assert foo.__doc__ == 'testing' - assert foo.__name__ == 'foo' + + assert foo.__doc__ == "testing" + assert foo.__name__ == "foo" assert foo.__module__ == __name__ def test_secure_filename(): - assert utils.secure_filename('My cool movie.mov') == 'My_cool_movie.mov' - assert utils.secure_filename('../../../etc/passwd') == 'etc_passwd' - assert utils.secure_filename(u'i contain cool \xfcml\xe4uts.txt') == \ - 'i_contain_cool_umlauts.txt' - assert utils.secure_filename('__filename__') == 'filename' - assert utils.secure_filename('foo$&^*)bar') == 'foobar' + assert utils.secure_filename("My cool movie.mov") == "My_cool_movie.mov" + assert utils.secure_filename("../../../etc/passwd") == "etc_passwd" + assert ( + utils.secure_filename(u"i contain cool \xfcml\xe4uts.txt") + == "i_contain_cool_umlauts.txt" + ) + assert utils.secure_filename("__filename__") == "filename" + assert utils.secure_filename("foo$&^*)bar") == "foobar" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 203bea6c..511d66de 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -11,33 +11,41 @@ import contextlib import json import os - -import pytest - import pickle +from datetime import datetime +from datetime import timedelta from io import BytesIO -from datetime import datetime, timedelta -from werkzeug._compat import iteritems -from tests import strict_eq +import pytest +from . import strict_eq from werkzeug import wrappers -from werkzeug.exceptions import SecurityError, RequestedRangeNotSatisfiable, \ - BadRequest +from werkzeug._compat import implements_iterator +from werkzeug._compat import iteritems +from werkzeug._compat import text_type +from werkzeug.datastructures import Accept +from werkzeug.datastructures import CharsetAccept +from werkzeug.datastructures import CombinedMultiDict +from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableList +from werkzeug.datastructures import ImmutableOrderedMultiDict +from werkzeug.datastructures import ImmutableTypeConversionDict +from werkzeug.datastructures import LanguageAccept +from werkzeug.datastructures import MIMEAccept +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest +from werkzeug.exceptions import RequestedRangeNotSatisfiable +from werkzeug.exceptions import SecurityError from werkzeug.http import generate_etag +from werkzeug.test import Client +from werkzeug.test import create_environ +from werkzeug.test import run_wsgi_app from werkzeug.wrappers.json import JSONMixin -from werkzeug.wsgi import LimitedStream, wrap_file -from werkzeug.datastructures import ( - MultiDict, ImmutableOrderedMultiDict, - ImmutableList, ImmutableTypeConversionDict, CharsetAccept, - MIMEAccept, LanguageAccept, Accept, CombinedMultiDict, Headers, -) -from werkzeug.test import Client, create_environ, run_wsgi_app -from werkzeug._compat import implements_iterator, text_type +from werkzeug.wsgi import LimitedStream +from werkzeug.wsgi import wrap_file class RequestTestResponse(wrappers.BaseResponse): - """Subclass of the normal response class we use to test response and base classes. Has some methods to test if things in the response match. @@ -53,16 +61,20 @@ class RequestTestResponse(wrappers.BaseResponse): def request_demo_app(environ, start_response): request = wrappers.BaseRequest(environ) - assert 'werkzeug.request' in environ - start_response('200 OK', [('Content-Type', 'text/plain')]) - return [pickle.dumps({ - 'args': request.args, - 'args_as_list': list(request.args.lists()), - 'form': request.form, - 'form_as_list': list(request.form.lists()), - 'environ': prepare_environ_pickle(request.environ), - 'data': request.get_data() - })] + assert "werkzeug.request" in environ + start_response("200 OK", [("Content-Type", "text/plain")]) + return [ + pickle.dumps( + { + "args": request.args, + "args_as_list": list(request.args.lists()), + "form": request.form, + "form_as_list": list(request.form.lists()), + "environ": prepare_environ_pickle(request.environ), + "data": request.get_data(), + } + ) + ] def prepare_environ_pickle(environ): @@ -77,127 +89,132 @@ def prepare_environ_pickle(environ): def assert_environ(environ, method): - strict_eq(environ['REQUEST_METHOD'], method) - strict_eq(environ['PATH_INFO'], '/') - strict_eq(environ['SCRIPT_NAME'], '') - strict_eq(environ['SERVER_NAME'], 'localhost') - strict_eq(environ['wsgi.version'], (1, 0)) - strict_eq(environ['wsgi.url_scheme'], 'http') + strict_eq(environ["REQUEST_METHOD"], method) + strict_eq(environ["PATH_INFO"], "/") + strict_eq(environ["SCRIPT_NAME"], "") + strict_eq(environ["SERVER_NAME"], "localhost") + strict_eq(environ["wsgi.version"], (1, 0)) + strict_eq(environ["wsgi.url_scheme"], "http") def test_base_request(): client = Client(request_demo_app, RequestTestResponse) # get requests - response = client.get('/?foo=bar&foo=hehe') - strict_eq(response['args'], MultiDict([('foo', u'bar'), ('foo', u'hehe')])) - strict_eq(response['args_as_list'], [('foo', [u'bar', u'hehe'])]) - strict_eq(response['form'], MultiDict()) - strict_eq(response['form_as_list'], []) - strict_eq(response['data'], b'') - assert_environ(response['environ'], 'GET') + response = client.get("/?foo=bar&foo=hehe") + strict_eq(response["args"], MultiDict([("foo", u"bar"), ("foo", u"hehe")])) + strict_eq(response["args_as_list"], [("foo", [u"bar", u"hehe"])]) + strict_eq(response["form"], MultiDict()) + strict_eq(response["form_as_list"], []) + strict_eq(response["data"], b"") + assert_environ(response["environ"], "GET") # post requests with form data - response = client.post('/?blub=blah', data='foo=blub+hehe&blah=42', - content_type='application/x-www-form-urlencoded') - strict_eq(response['args'], MultiDict([('blub', u'blah')])) - strict_eq(response['args_as_list'], [('blub', [u'blah'])]) - strict_eq(response['form'], MultiDict([('foo', u'blub hehe'), ('blah', u'42')])) - strict_eq(response['data'], b'') + response = client.post( + "/?blub=blah", + data="foo=blub+hehe&blah=42", + content_type="application/x-www-form-urlencoded", + ) + strict_eq(response["args"], MultiDict([("blub", u"blah")])) + strict_eq(response["args_as_list"], [("blub", [u"blah"])]) + strict_eq(response["form"], MultiDict([("foo", u"blub hehe"), ("blah", u"42")])) + strict_eq(response["data"], b"") # currently we do not guarantee that the values are ordered correctly # for post data. # strict_eq(response['form_as_list'], [('foo', ['blub hehe']), ('blah', ['42'])]) - assert_environ(response['environ'], 'POST') + assert_environ(response["environ"], "POST") # patch requests with form data - response = client.patch('/?blub=blah', data='foo=blub+hehe&blah=42', - content_type='application/x-www-form-urlencoded') - strict_eq(response['args'], MultiDict([('blub', u'blah')])) - strict_eq(response['args_as_list'], [('blub', [u'blah'])]) - strict_eq(response['form'], - MultiDict([('foo', u'blub hehe'), ('blah', u'42')])) - strict_eq(response['data'], b'') - assert_environ(response['environ'], 'PATCH') + response = client.patch( + "/?blub=blah", + data="foo=blub+hehe&blah=42", + content_type="application/x-www-form-urlencoded", + ) + strict_eq(response["args"], MultiDict([("blub", u"blah")])) + strict_eq(response["args_as_list"], [("blub", [u"blah"])]) + strict_eq(response["form"], MultiDict([("foo", u"blub hehe"), ("blah", u"42")])) + strict_eq(response["data"], b"") + assert_environ(response["environ"], "PATCH") # post requests with json data json = b'{"foo": "bar", "blub": "blah"}' - response = client.post('/?a=b', data=json, content_type='application/json') - strict_eq(response['data'], json) - strict_eq(response['args'], MultiDict([('a', u'b')])) - strict_eq(response['form'], MultiDict()) + response = client.post("/?a=b", data=json, content_type="application/json") + strict_eq(response["data"], json) + strict_eq(response["args"], MultiDict([("a", u"b")])) + strict_eq(response["form"], MultiDict()) def test_query_string_is_bytes(): - req = wrappers.Request.from_values(u'/?foo=%2f') - strict_eq(req.query_string, b'foo=%2f') + req = wrappers.Request.from_values(u"/?foo=%2f") + strict_eq(req.query_string, b"foo=%2f") def test_request_repr(): - req = wrappers.Request.from_values('/foobar') + req = wrappers.Request.from_values("/foobar") assert "<Request 'http://localhost/foobar' [GET]>" == repr(req) # test with non-ascii characters - req = wrappers.Request.from_values('/привет') + req = wrappers.Request.from_values("/привет") assert "<Request 'http://localhost/привет' [GET]>" == repr(req) # test with unicode type for python 2 - req = wrappers.Request.from_values(u'/привет') + req = wrappers.Request.from_values(u"/привет") assert "<Request 'http://localhost/привет' [GET]>" == repr(req) def test_access_route(): - req = wrappers.Request.from_values(headers={ - 'X-Forwarded-For': '192.168.1.2, 192.168.1.1' - }) - req.environ['REMOTE_ADDR'] = '192.168.1.3' - assert req.access_route == ['192.168.1.2', '192.168.1.1'] - strict_eq(req.remote_addr, '192.168.1.3') + req = wrappers.Request.from_values( + headers={"X-Forwarded-For": "192.168.1.2, 192.168.1.1"} + ) + req.environ["REMOTE_ADDR"] = "192.168.1.3" + assert req.access_route == ["192.168.1.2", "192.168.1.1"] + strict_eq(req.remote_addr, "192.168.1.3") req = wrappers.Request.from_values() - req.environ['REMOTE_ADDR'] = '192.168.1.3' - strict_eq(list(req.access_route), ['192.168.1.3']) + req.environ["REMOTE_ADDR"] = "192.168.1.3" + strict_eq(list(req.access_route), ["192.168.1.3"]) def test_url_request_descriptors(): - req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test') - strict_eq(req.path, u'/bar') - strict_eq(req.full_path, u'/bar?foo=baz') - strict_eq(req.script_root, u'/test') - strict_eq(req.url, u'http://example.com/test/bar?foo=baz') - strict_eq(req.base_url, u'http://example.com/test/bar') - strict_eq(req.url_root, u'http://example.com/test/') - strict_eq(req.host_url, u'http://example.com/') - strict_eq(req.host, 'example.com') - strict_eq(req.scheme, 'http') + req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test") + strict_eq(req.path, u"/bar") + strict_eq(req.full_path, u"/bar?foo=baz") + strict_eq(req.script_root, u"/test") + strict_eq(req.url, u"http://example.com/test/bar?foo=baz") + strict_eq(req.base_url, u"http://example.com/test/bar") + strict_eq(req.url_root, u"http://example.com/test/") + strict_eq(req.host_url, u"http://example.com/") + strict_eq(req.host, "example.com") + strict_eq(req.scheme, "http") - req = wrappers.Request.from_values('/bar?foo=baz', 'https://example.com/test') - strict_eq(req.scheme, 'https') + req = wrappers.Request.from_values("/bar?foo=baz", "https://example.com/test") + strict_eq(req.scheme, "https") def test_url_request_descriptors_query_quoting(): - next = 'http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash' - req = wrappers.Request.from_values('/bar?next=' + next, 'http://example.com/') - assert req.path == u'/bar' - strict_eq(req.full_path, u'/bar?next=' + next) - strict_eq(req.url, u'http://example.com/bar?next=' + next) + next = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash" + req = wrappers.Request.from_values("/bar?next=" + next, "http://example.com/") + assert req.path == u"/bar" + strict_eq(req.full_path, u"/bar?next=" + next) + strict_eq(req.url, u"http://example.com/bar?next=" + next) def test_url_request_descriptors_hosts(): - req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test') - req.trusted_hosts = ['example.com'] - strict_eq(req.path, u'/bar') - strict_eq(req.full_path, u'/bar?foo=baz') - strict_eq(req.script_root, u'/test') - strict_eq(req.url, u'http://example.com/test/bar?foo=baz') - strict_eq(req.base_url, u'http://example.com/test/bar') - strict_eq(req.url_root, u'http://example.com/test/') - strict_eq(req.host_url, u'http://example.com/') - strict_eq(req.host, 'example.com') - strict_eq(req.scheme, 'http') - - req = wrappers.Request.from_values('/bar?foo=baz', 'https://example.com/test') - strict_eq(req.scheme, 'https') - - req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test') - req.trusted_hosts = ['example.org'] + req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test") + req.trusted_hosts = ["example.com"] + strict_eq(req.path, u"/bar") + strict_eq(req.full_path, u"/bar?foo=baz") + strict_eq(req.script_root, u"/test") + strict_eq(req.url, u"http://example.com/test/bar?foo=baz") + strict_eq(req.base_url, u"http://example.com/test/bar") + strict_eq(req.url_root, u"http://example.com/test/") + strict_eq(req.host_url, u"http://example.com/") + strict_eq(req.host, "example.com") + strict_eq(req.scheme, "http") + + req = wrappers.Request.from_values("/bar?foo=baz", "https://example.com/test") + strict_eq(req.scheme, "https") + + req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test") + req.trusted_hosts = ["example.org"] pytest.raises(SecurityError, lambda: req.url) pytest.raises(SecurityError, lambda: req.base_url) pytest.raises(SecurityError, lambda: req.url_root) @@ -206,89 +223,106 @@ def test_url_request_descriptors_hosts(): def test_authorization_mixin(): - request = wrappers.Request.from_values(headers={ - 'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==' - }) + request = wrappers.Request.from_values( + headers={"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="} + ) a = request.authorization - strict_eq(a.type, 'basic') - strict_eq(a.username, u'Aladdin') - strict_eq(a.password, u'open sesame') + strict_eq(a.type, "basic") + strict_eq(a.username, u"Aladdin") + strict_eq(a.password, u"open sesame") def test_authorization_with_unicode(): - request = wrappers.Request.from_values(headers={ - 'Authorization': 'Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==' - }) + request = wrappers.Request.from_values( + headers={"Authorization": "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw=="} + ) a = request.authorization - strict_eq(a.type, 'basic') - strict_eq(a.username, u'русскиЁ') - strict_eq(a.password, u'Буквы') + strict_eq(a.type, "basic") + strict_eq(a.username, u"русскиЁ") + strict_eq(a.password, u"Буквы") def test_stream_only_mixing(): request = wrappers.PlainRequest.from_values( - data=b'foo=blub+hehe', - content_type='application/x-www-form-urlencoded' + data=b"foo=blub+hehe", content_type="application/x-www-form-urlencoded" ) assert list(request.files.items()) == [] assert list(request.form.items()) == [] pytest.raises(AttributeError, lambda: request.data) - strict_eq(request.stream.read(), b'foo=blub+hehe') + strict_eq(request.stream.read(), b"foo=blub+hehe") def test_request_application(): @wrappers.Request.application def application(request): - return wrappers.Response('Hello World!') + return wrappers.Response("Hello World!") @wrappers.Request.application def failing_application(request): raise BadRequest() resp = wrappers.Response.from_app(application, create_environ()) - assert resp.data == b'Hello World!' + assert resp.data == b"Hello World!" assert resp.status_code == 200 resp = wrappers.Response.from_app(failing_application, create_environ()) - assert b'Bad Request' in resp.data + assert b"Bad Request" in resp.data assert resp.status_code == 400 def test_base_response(): # unicode - response = wrappers.BaseResponse(u'öäü') - strict_eq(response.get_data(), u'öäü'.encode('utf-8')) + response = wrappers.BaseResponse(u"öäü") + strict_eq(response.get_data(), u"öäü".encode("utf-8")) # writing - response = wrappers.Response('foo') - response.stream.write('bar') - strict_eq(response.get_data(), b'foobar') + response = wrappers.Response("foo") + response.stream.write("bar") + strict_eq(response.get_data(), b"foobar") # set cookie response = wrappers.BaseResponse() - response.set_cookie('foo', value='bar', max_age=60, expires=0, - path='/blub', domain='example.org', samesite='Strict') - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, ' - '01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; ' - 'SameSite=Strict') - ]) + response.set_cookie( + "foo", + value="bar", + max_age=60, + expires=0, + path="/blub", + domain="example.org", + samesite="Strict", + ) + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=bar; Domain=example.org; Expires=Thu, " + "01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; " + "SameSite=Strict", + ), + ], + ) # delete cookie response = wrappers.BaseResponse() - response.delete_cookie('foo') - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=; Expires=Thu, 01-Jan-1970 00:00:00 GMT; Max-Age=0; Path=/') - ]) + response.delete_cookie("foo") + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=; Expires=Thu, 01-Jan-1970 00:00:00 GMT; Max-Age=0; Path=/", + ), + ], + ) # close call forwarding closed = [] @implements_iterator class Iterable(object): - def __next__(self): raise StopIteration() @@ -297,13 +331,12 @@ def test_base_response(): def close(self): closed.append(True) + response = wrappers.BaseResponse(Iterable()) response.call_on_close(lambda: closed.append(True)) - app_iter, status, headers = run_wsgi_app(response, - create_environ(), - buffered=True) - strict_eq(status, '200 OK') - strict_eq(''.join(app_iter), '') + app_iter, status, headers = run_wsgi_app(response, create_environ(), buffered=True) + strict_eq(status, "200 OK") + strict_eq("".join(app_iter), "") strict_eq(len(closed), 2) # with statement @@ -317,36 +350,36 @@ def test_base_response(): def test_response_status_codes(): response = wrappers.BaseResponse() response.status_code = 404 - strict_eq(response.status, '404 NOT FOUND') - response.status = '200 OK' + strict_eq(response.status, "404 NOT FOUND") + response.status = "200 OK" strict_eq(response.status_code, 200) - response.status = '999 WTF' + response.status = "999 WTF" strict_eq(response.status_code, 999) response.status_code = 588 strict_eq(response.status_code, 588) - strict_eq(response.status, '588 UNKNOWN') - response.status = 'wtf' + strict_eq(response.status, "588 UNKNOWN") + response.status = "wtf" strict_eq(response.status_code, 0) - strict_eq(response.status, '0 wtf') + strict_eq(response.status, "0 wtf") # invalid status codes with pytest.raises(ValueError) as empty_string_error: - wrappers.BaseResponse(None, '') - assert 'Empty status argument' in str(empty_string_error) + wrappers.BaseResponse(None, "") + assert "Empty status argument" in str(empty_string_error) with pytest.raises(TypeError) as invalid_type_error: wrappers.BaseResponse(None, tuple()) - assert 'Invalid status argument' in str(invalid_type_error) + assert "Invalid status argument" in str(invalid_type_error) def test_type_forcing(): def wsgi_application(environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - return ['Hello World!'] - base_response = wrappers.BaseResponse('Hello World!', content_type='text/html') + start_response("200 OK", [("Content-Type", "text/html")]) + return ["Hello World!"] - class SpecialResponse(wrappers.Response): + base_response = wrappers.BaseResponse("Hello World!", content_type="text/html") + class SpecialResponse(wrappers.Response): def foo(self): return 42 @@ -358,54 +391,63 @@ def test_type_forcing(): response = SpecialResponse.force_type(orig_resp, fake_env) assert response.__class__ is SpecialResponse strict_eq(response.foo(), 42) - strict_eq(response.get_data(), b'Hello World!') - assert response.content_type == 'text/html' + strict_eq(response.get_data(), b"Hello World!") + assert response.content_type == "text/html" # without env, no arbitrary conversion pytest.raises(TypeError, SpecialResponse.force_type, wsgi_application) def test_accept_mixin(): - request = wrappers.Request({ - 'HTTP_ACCEPT': 'text/xml,application/xml,application/xhtml+xml,' - 'text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5', - 'HTTP_ACCEPT_CHARSET': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7', - 'HTTP_ACCEPT_ENCODING': 'gzip,deflate', - 'HTTP_ACCEPT_LANGUAGE': 'en-us,en;q=0.5' - }) - assert request.accept_mimetypes == MIMEAccept([ - ('text/xml', 1), ('image/png', 1), ('application/xml', 1), - ('application/xhtml+xml', 1), ('text/html', 0.9), - ('text/plain', 0.8), ('*/*', 0.5) - ]) - strict_eq(request.accept_charsets, CharsetAccept([ - ('ISO-8859-1', 1), ('utf-8', 0.7), ('*', 0.7) - ])) - strict_eq(request.accept_encodings, Accept([ - ('gzip', 1), ('deflate', 1)])) - strict_eq(request.accept_languages, LanguageAccept([ - ('en-us', 1), ('en', 0.5)])) - - request = wrappers.Request({'HTTP_ACCEPT': ''}) + request = wrappers.Request( + { + "HTTP_ACCEPT": "text/xml,application/xml,application/xhtml+xml," + "text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5", + "HTTP_ACCEPT_CHARSET": "ISO-8859-1,utf-8;q=0.7,*;q=0.7", + "HTTP_ACCEPT_ENCODING": "gzip,deflate", + "HTTP_ACCEPT_LANGUAGE": "en-us,en;q=0.5", + } + ) + assert request.accept_mimetypes == MIMEAccept( + [ + ("text/xml", 1), + ("image/png", 1), + ("application/xml", 1), + ("application/xhtml+xml", 1), + ("text/html", 0.9), + ("text/plain", 0.8), + ("*/*", 0.5), + ] + ) + strict_eq( + request.accept_charsets, + CharsetAccept([("ISO-8859-1", 1), ("utf-8", 0.7), ("*", 0.7)]), + ) + strict_eq(request.accept_encodings, Accept([("gzip", 1), ("deflate", 1)])) + strict_eq(request.accept_languages, LanguageAccept([("en-us", 1), ("en", 0.5)])) + + request = wrappers.Request({"HTTP_ACCEPT": ""}) strict_eq(request.accept_mimetypes, MIMEAccept()) def test_etag_request_mixin(): - request = wrappers.Request({ - 'HTTP_CACHE_CONTROL': 'no-store, no-cache', - 'HTTP_IF_MATCH': 'W/"foo", bar, "baz"', - 'HTTP_IF_NONE_MATCH': 'W/"foo", bar, "baz"', - 'HTTP_IF_MODIFIED_SINCE': 'Tue, 22 Jan 2008 11:18:44 GMT', - 'HTTP_IF_UNMODIFIED_SINCE': 'Tue, 22 Jan 2008 11:18:44 GMT' - }) + request = wrappers.Request( + { + "HTTP_CACHE_CONTROL": "no-store, no-cache", + "HTTP_IF_MATCH": 'W/"foo", bar, "baz"', + "HTTP_IF_NONE_MATCH": 'W/"foo", bar, "baz"', + "HTTP_IF_MODIFIED_SINCE": "Tue, 22 Jan 2008 11:18:44 GMT", + "HTTP_IF_UNMODIFIED_SINCE": "Tue, 22 Jan 2008 11:18:44 GMT", + } + ) assert request.cache_control.no_store assert request.cache_control.no_cache for etags in request.if_match, request.if_none_match: - assert etags('bar') + assert etags("bar") assert etags.contains_raw('W/"foo"') - assert etags.contains_weak('foo') - assert not etags.contains('foo') + assert etags.contains_weak("foo") + assert not etags.contains("foo") assert request.if_modified_since == datetime(2008, 1, 22, 11, 18, 44) assert request.if_unmodified_since == datetime(2008, 1, 22, 11, 18, 44) @@ -413,60 +455,170 @@ def test_etag_request_mixin(): def test_user_agent_mixin(): user_agents = [ - ('Mozilla/5.0 (Macintosh; U; Intel Mac OS X; en-US; rv:1.8.1.11) ' - 'Gecko/20071127 Firefox/2.0.0.11', 'firefox', 'macos', '2.0.0.11', - 'en-US'), - ('Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; de-DE) Opera 8.54', - 'opera', 'windows', '8.54', 'de-DE'), - ('Mozilla/5.0 (iPhone; U; CPU like Mac OS X; en) AppleWebKit/420 ' - '(KHTML, like Gecko) Version/3.0 Mobile/1A543a Safari/419.3', - 'safari', 'iphone', '3.0', 'en'), - ('Bot Googlebot/2.1 ( http://www.googlebot.com/bot.html)', - 'google', None, '2.1', None), - ('Mozilla/5.0 (X11; CrOS armv7l 3701.81.0) AppleWebKit/537.31 ' - '(KHTML, like Gecko) Chrome/26.0.1410.57 Safari/537.31', - 'chrome', 'chromeos', '26.0.1410.57', None), - ('Mozilla/5.0 (Windows NT 6.3; Trident/7.0; .NET4.0E; rv:11.0) like Gecko', - 'msie', 'windows', '11.0', None), - ('Mozilla/5.0 (SymbianOS/9.3; Series60/3.2 NokiaE5-00/101.003; ' - 'Profile/MIDP-2.1 Configuration/CLDC-1.1 ) AppleWebKit/533.4 (KHTML, like Gecko) ' - 'NokiaBrowser/7.3.1.35 Mobile Safari/533.4 3gpp-gba', - 'safari', 'symbian', '533.4', None), - ('Mozilla/5.0 (X11; OpenBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0', - 'firefox', 'openbsd', '45.0', None), - ('Mozilla/5.0 (X11; NetBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0', - 'firefox', 'netbsd', '45.0', None), - ('Mozilla/5.0 (X11; FreeBSD amd64) AppleWebKit/537.36 (KHTML, like Gecko) ' - 'Chrome/48.0.2564.103 Safari/537.36', - 'chrome', 'freebsd', '48.0.2564.103', None), - ('Mozilla/5.0 (X11; FreeBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0', - 'firefox', 'freebsd', '45.0', None), - ('Mozilla/5.0 (X11; U; NetBSD amd64; en-US; rv:) Gecko/20150921 SeaMonkey/1.1.18', - 'seamonkey', 'netbsd', '1.1.18', 'en-US'), - ('Mozilla/5.0 (Windows; U; Windows NT 6.2; WOW64; rv:1.8.0.7) ' - 'Gecko/20110321 MultiZilla/4.33.2.6a SeaMonkey/8.6.55', - 'seamonkey', 'windows', '8.6.55', None), - ('Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20120427 Firefox/12.0 SeaMonkey/2.9', - 'seamonkey', 'linux', '2.9', None), - ('Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)', - 'baidu', None, '2.0', None), - ('Mozilla/5.0 (X11; SunOS i86pc; rv:38.0) Gecko/20100101 Firefox/38.0', - 'firefox', 'solaris', '38.0', None), - ('Mozilla/5.0 (X11; Linux x86_64; rv:38.0) Gecko/20100101 Firefox/38.0 Iceweasel/38.7.1', - 'firefox', 'linux', '38.0', None), - ('Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) ' - 'Chrome/50.0.2661.75 Safari/537.36', - 'chrome', 'windows', '50.0.2661.75', None), - ('Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)', - 'bing', None, '2.0', None), - ('Mozilla/5.0 (X11; DragonFly x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' - 'Chrome/47.0.2526.106 Safari/537.36', 'chrome', 'dragonflybsd', '47.0.2526.106', None), - ('Mozilla/5.0 (X11; U; DragonFly i386; de; rv:1.9.1) Gecko/20090720 Firefox/3.5.1', - 'firefox', 'dragonflybsd', '3.5.1', 'de') - + ( + "Mozilla/5.0 (Macintosh; U; Intel Mac OS X; en-US; rv:1.8.1.11) " + "Gecko/20071127 Firefox/2.0.0.11", + "firefox", + "macos", + "2.0.0.11", + "en-US", + ), + ( + "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; de-DE) Opera 8.54", + "opera", + "windows", + "8.54", + "de-DE", + ), + ( + "Mozilla/5.0 (iPhone; U; CPU like Mac OS X; en) AppleWebKit/420 " + "(KHTML, like Gecko) Version/3.0 Mobile/1A543a Safari/419.3", + "safari", + "iphone", + "3.0", + "en", + ), + ( + "Bot Googlebot/2.1 ( http://www.googlebot.com/bot.html)", + "google", + None, + "2.1", + None, + ), + ( + "Mozilla/5.0 (X11; CrOS armv7l 3701.81.0) AppleWebKit/537.31 " + "(KHTML, like Gecko) Chrome/26.0.1410.57 Safari/537.31", + "chrome", + "chromeos", + "26.0.1410.57", + None, + ), + ( + "Mozilla/5.0 (Windows NT 6.3; Trident/7.0; .NET4.0E; rv:11.0) like Gecko", + "msie", + "windows", + "11.0", + None, + ), + ( + "Mozilla/5.0 (SymbianOS/9.3; Series60/3.2 NokiaE5-00/101.003; " + "Profile/MIDP-2.1 Configuration/CLDC-1.1 ) AppleWebKit/533.4 " + "(KHTML, like Gecko) NokiaBrowser/7.3.1.35 Mobile Safari/533.4 3gpp-gba", + "safari", + "symbian", + "533.4", + None, + ), + ( + "Mozilla/5.0 (X11; OpenBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0", + "firefox", + "openbsd", + "45.0", + None, + ), + ( + "Mozilla/5.0 (X11; NetBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0", + "firefox", + "netbsd", + "45.0", + None, + ), + ( + "Mozilla/5.0 (X11; FreeBSD amd64) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/48.0.2564.103 Safari/537.36", + "chrome", + "freebsd", + "48.0.2564.103", + None, + ), + ( + "Mozilla/5.0 (X11; FreeBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0", + "firefox", + "freebsd", + "45.0", + None, + ), + ( + "Mozilla/5.0 (X11; U; NetBSD amd64; en-US; rv:) Gecko/20150921 " + "SeaMonkey/1.1.18", + "seamonkey", + "netbsd", + "1.1.18", + "en-US", + ), + ( + "Mozilla/5.0 (Windows; U; Windows NT 6.2; WOW64; rv:1.8.0.7) " + "Gecko/20110321 MultiZilla/4.33.2.6a SeaMonkey/8.6.55", + "seamonkey", + "windows", + "8.6.55", + None, + ), + ( + "Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20120427 Firefox/12.0 " + "SeaMonkey/2.9", + "seamonkey", + "linux", + "2.9", + None, + ), + ( + "Mozilla/5.0 (compatible; Baiduspider/2.0; " + "+http://www.baidu.com/search/spider.html)", + "baidu", + None, + "2.0", + None, + ), + ( + "Mozilla/5.0 (X11; SunOS i86pc; rv:38.0) Gecko/20100101 Firefox/38.0", + "firefox", + "solaris", + "38.0", + None, + ), + ( + "Mozilla/5.0 (X11; Linux x86_64; rv:38.0) Gecko/20100101 Firefox/38.0 " + "Iceweasel/38.7.1", + "firefox", + "linux", + "38.0", + None, + ), + ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/50.0.2661.75 Safari/537.36", + "chrome", + "windows", + "50.0.2661.75", + None, + ), + ( + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)", + "bing", + None, + "2.0", + None, + ), + ( + "Mozilla/5.0 (X11; DragonFly x86_64) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/47.0.2526.106 Safari/537.36", + "chrome", + "dragonflybsd", + "47.0.2526.106", + None, + ), + ( + "Mozilla/5.0 (X11; U; DragonFly i386; de; rv:1.9.1) " + "Gecko/20090720 Firefox/3.5.1", + "firefox", + "dragonflybsd", + "3.5.1", + "de", + ), ] for ua, browser, platform, version, lang in user_agents: - request = wrappers.Request({'HTTP_USER_AGENT': ua}) + request = wrappers.Request({"HTTP_USER_AGENT": ua}) strict_eq(request.user_agent.browser, browser) strict_eq(request.user_agent.platform, platform) strict_eq(request.user_agent.version, version) @@ -475,13 +627,12 @@ def test_user_agent_mixin(): strict_eq(request.user_agent.to_header(), ua) strict_eq(str(request.user_agent), ua) - request = wrappers.Request({'HTTP_USER_AGENT': 'foo'}) + request = wrappers.Request({"HTTP_USER_AGENT": "foo"}) assert not request.user_agent def test_stream_wrapping(): class LowercasingStream(object): - def __init__(self, stream): self._stream = stream @@ -491,121 +642,121 @@ def test_stream_wrapping(): def readline(self, size=-1): return self._stream.readline(size).lower() - data = b'foo=Hello+World' + data = b"foo=Hello+World" req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) req.stream = LowercasingStream(req.stream) - assert req.form['foo'] == 'hello world' + assert req.form["foo"] == "hello world" def test_data_descriptor_triggers_parsing(): - data = b'foo=Hello+World' + data = b"foo=Hello+World" req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) - assert req.data == b'' - assert req.form['foo'] == u'Hello World' + assert req.data == b"" + assert req.form["foo"] == u"Hello World" def test_get_data_method_parsing_caching_behavior(): - data = b'foo=Hello+World' + data = b"foo=Hello+World" req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) # get_data() caches, so form stays available assert req.get_data() == data - assert req.form['foo'] == u'Hello World' + assert req.form["foo"] == u"Hello World" assert req.get_data() == data # here we access the form data first, caching is bypassed req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') - assert req.form['foo'] == u'Hello World' - assert req.get_data() == b'' + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) + assert req.form["foo"] == u"Hello World" + assert req.get_data() == b"" # Another case is uncached get data which trashes everything req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) assert req.get_data(cache=False) == data - assert req.get_data(cache=False) == b'' + assert req.get_data(cache=False) == b"" assert req.form == {} # Or we can implicitly start the form parser which is similar to # the old .data behavior req = wrappers.Request.from_values( - '/', method='POST', data=data, - content_type='application/x-www-form-urlencoded') - assert req.get_data(parse_form_data=True) == b'' - assert req.form['foo'] == u'Hello World' + "/", method="POST", data=data, content_type="application/x-www-form-urlencoded" + ) + assert req.get_data(parse_form_data=True) == b"" + assert req.form["foo"] == u"Hello World" def test_etag_response_mixin(): - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") assert response.get_etag() == (None, None) response.add_etag() - assert response.get_etag() == ('b10a8db164e0754105b7a99be72e3fe5', False) + assert response.get_etag() == ("b10a8db164e0754105b7a99be72e3fe5", False) assert not response.cache_control response.cache_control.must_revalidate = True response.cache_control.max_age = 60 - response.headers['Content-Length'] = len(response.get_data()) - assert response.headers['Cache-Control'] in ('must-revalidate, max-age=60', - 'max-age=60, must-revalidate') + response.headers["Content-Length"] = len(response.get_data()) + assert response.headers["Cache-Control"] in ( + "must-revalidate, max-age=60", + "max-age=60, must-revalidate", + ) - assert 'date' not in response.headers + assert "date" not in response.headers env = create_environ() - env.update({ - 'REQUEST_METHOD': 'GET', - 'HTTP_IF_NONE_MATCH': response.get_etag()[0] - }) + env.update({"REQUEST_METHOD": "GET", "HTTP_IF_NONE_MATCH": response.get_etag()[0]}) response.make_conditional(env) - assert 'date' in response.headers + assert "date" in response.headers # after the thing is invoked by the server as wsgi application # (we're emulating this here), there must not be any entity # headers left and the status code would have to be 304 resp = wrappers.Response.from_app(response, env) assert resp.status_code == 304 - assert 'content-length' not in resp.headers + assert "content-length" not in resp.headers # make sure date is not overriden - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") response.date = 1337 d = response.date response.make_conditional(env) assert response.date == d # make sure content length is only set if missing - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") response.content_length = 999 response.make_conditional(env) assert response.content_length == 999 def test_etag_response_412(): - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") assert response.get_etag() == (None, None) response.add_etag() - assert response.get_etag() == ('b10a8db164e0754105b7a99be72e3fe5', False) + assert response.get_etag() == ("b10a8db164e0754105b7a99be72e3fe5", False) assert not response.cache_control response.cache_control.must_revalidate = True response.cache_control.max_age = 60 - response.headers['Content-Length'] = len(response.get_data()) - assert response.headers['Cache-Control'] in ('must-revalidate, max-age=60', - 'max-age=60, must-revalidate') + response.headers["Content-Length"] = len(response.get_data()) + assert response.headers["Cache-Control"] in ( + "must-revalidate, max-age=60", + "max-age=60, must-revalidate", + ) - assert 'date' not in response.headers + assert "date" not in response.headers env = create_environ() - env.update({ - 'REQUEST_METHOD': 'GET', - 'HTTP_IF_MATCH': response.get_etag()[0] + "xyz" - }) + env.update( + {"REQUEST_METHOD": "GET", "HTTP_IF_MATCH": response.get_etag()[0] + "xyz"} + ) response.make_conditional(env) - assert 'date' in response.headers + assert "date" in response.headers # after the thing is invoked by the server as wsgi application # (we're emulating this here), there must not be any entity @@ -613,17 +764,17 @@ def test_etag_response_412(): resp = wrappers.Response.from_app(response, env) assert resp.status_code == 412 # Make sure there is a body still - assert resp.data != b'' + assert resp.data != b"" # make sure date is not overriden - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") response.date = 1337 d = response.date response.make_conditional(env) assert response.date == d # make sure content length is only set if missing - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") response.content_length = 999 response.make_conditional(env) assert response.content_length == 999 @@ -631,77 +782,78 @@ def test_etag_response_412(): def test_range_request_basic(): env = create_environ() - response = wrappers.Response('Hello World') - env['HTTP_RANGE'] = 'bytes=0-4' + response = wrappers.Response("Hello World") + env["HTTP_RANGE"] = "bytes=0-4" response.make_conditional(env, accept_ranges=True, complete_length=11) assert response.status_code == 206 - assert response.headers['Accept-Ranges'] == 'bytes' - assert response.headers['Content-Range'] == 'bytes 0-4/11' - assert response.headers['Content-Length'] == '5' - assert response.data == b'Hello' + assert response.headers["Accept-Ranges"] == "bytes" + assert response.headers["Content-Range"] == "bytes 0-4/11" + assert response.headers["Content-Length"] == "5" + assert response.data == b"Hello" def test_range_request_out_of_bound(): env = create_environ() - response = wrappers.Response('Hello World') - env['HTTP_RANGE'] = 'bytes=6-666' + response = wrappers.Response("Hello World") + env["HTTP_RANGE"] = "bytes=6-666" response.make_conditional(env, accept_ranges=True, complete_length=11) assert response.status_code == 206 - assert response.headers['Accept-Ranges'] == 'bytes' - assert response.headers['Content-Range'] == 'bytes 6-10/11' - assert response.headers['Content-Length'] == '5' - assert response.data == b'World' + assert response.headers["Accept-Ranges"] == "bytes" + assert response.headers["Content-Range"] == "bytes 6-10/11" + assert response.headers["Content-Length"] == "5" + assert response.data == b"World" def test_range_request_with_file(): env = create_environ() - resources = os.path.join(os.path.dirname(__file__), 'res') - fname = os.path.join(resources, 'test.txt') - with open(fname, 'rb') as f: + resources = os.path.join(os.path.dirname(__file__), "res") + fname = os.path.join(resources, "test.txt") + with open(fname, "rb") as f: fcontent = f.read() - with open(fname, 'rb') as f: + with open(fname, "rb") as f: response = wrappers.Response(wrap_file(env, f)) - env['HTTP_RANGE'] = 'bytes=0-0' - response.make_conditional(env, accept_ranges=True, complete_length=len(fcontent)) + env["HTTP_RANGE"] = "bytes=0-0" + response.make_conditional( + env, accept_ranges=True, complete_length=len(fcontent) + ) assert response.status_code == 206 - assert response.headers['Accept-Ranges'] == 'bytes' - assert response.headers['Content-Range'] == 'bytes 0-0/%d' % len(fcontent) - assert response.headers['Content-Length'] == '1' + assert response.headers["Accept-Ranges"] == "bytes" + assert response.headers["Content-Range"] == "bytes 0-0/%d" % len(fcontent) + assert response.headers["Content-Length"] == "1" assert response.data == fcontent[:1] def test_range_request_with_complete_file(): env = create_environ() - resources = os.path.join(os.path.dirname(__file__), 'res') - fname = os.path.join(resources, 'test.txt') - with open(fname, 'rb') as f: + resources = os.path.join(os.path.dirname(__file__), "res") + fname = os.path.join(resources, "test.txt") + with open(fname, "rb") as f: fcontent = f.read() - with open(fname, 'rb') as f: + with open(fname, "rb") as f: fsize = os.path.getsize(fname) response = wrappers.Response(wrap_file(env, f)) - env['HTTP_RANGE'] = 'bytes=0-%d' % (fsize - 1) - response.make_conditional(env, accept_ranges=True, - complete_length=fsize) + env["HTTP_RANGE"] = "bytes=0-%d" % (fsize - 1) + response.make_conditional(env, accept_ranges=True, complete_length=fsize) assert response.status_code == 200 - assert response.headers['Accept-Ranges'] == 'bytes' - assert 'Content-Range' not in response.headers - assert response.headers['Content-Length'] == str(fsize) + assert response.headers["Accept-Ranges"] == "bytes" + assert "Content-Range" not in response.headers + assert response.headers["Content-Length"] == str(fsize) assert response.data == fcontent def test_range_request_without_complete_length(): env = create_environ() - response = wrappers.Response('Hello World') - env['HTTP_RANGE'] = 'bytes=-' + response = wrappers.Response("Hello World") + env["HTTP_RANGE"] = "bytes=-" response.make_conditional(env, accept_ranges=True, complete_length=None) assert response.status_code == 200 - assert response.data == b'Hello World' + assert response.data == b"Hello World" def test_invalid_range_request(): env = create_environ() - response = wrappers.Response('Hello World') - env['HTTP_RANGE'] = 'bytes=-' + response = wrappers.Response("Hello World") + env["HTTP_RANGE"] = "bytes=-" with pytest.raises(RequestedRangeNotSatisfiable): response.make_conditional(env, accept_ranges=True, complete_length=11) @@ -713,69 +865,68 @@ def test_etag_response_mixin_freezing(): class WithoutFreeze(wrappers.BaseResponse, wrappers.ETagResponseMixin): pass - response = WithFreeze('Hello World') + response = WithFreeze("Hello World") response.freeze() - strict_eq(response.get_etag(), - (text_type(generate_etag(b'Hello World')), False)) - response = WithoutFreeze('Hello World') + strict_eq(response.get_etag(), (text_type(generate_etag(b"Hello World")), False)) + response = WithoutFreeze("Hello World") response.freeze() assert response.get_etag() == (None, None) - response = wrappers.Response('Hello World') + response = wrappers.Response("Hello World") response.freeze() assert response.get_etag() == (None, None) def test_authenticate_mixin(): resp = wrappers.Response() - resp.www_authenticate.type = 'basic' - resp.www_authenticate.realm = 'Testing' - strict_eq(resp.headers['WWW-Authenticate'], u'Basic realm="Testing"') + resp.www_authenticate.type = "basic" + resp.www_authenticate.realm = "Testing" + strict_eq(resp.headers["WWW-Authenticate"], u'Basic realm="Testing"') resp.www_authenticate.realm = None resp.www_authenticate.type = None - assert 'WWW-Authenticate' not in resp.headers + assert "WWW-Authenticate" not in resp.headers def test_authenticate_mixin_quoted_qop(): # Example taken from https://github.com/pallets/werkzeug/issues/633 resp = wrappers.Response() - resp.www_authenticate.set_digest('REALM', 'NONCE', qop=("auth", "auth-int")) + resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth", "auth-int")) - actual = set((resp.headers['WWW-Authenticate'] + ',').split()) + actual = set((resp.headers["WWW-Authenticate"] + ",").split()) expected = set('Digest nonce="NONCE", realm="REALM", qop="auth, auth-int",'.split()) assert actual == expected - resp.www_authenticate.set_digest('REALM', 'NONCE', qop=("auth",)) + resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth",)) - actual = set((resp.headers['WWW-Authenticate'] + ',').split()) + actual = set((resp.headers["WWW-Authenticate"] + ",").split()) expected = set('Digest nonce="NONCE", realm="REALM", qop="auth",'.split()) assert actual == expected def test_response_stream_mixin(): response = wrappers.Response() - response.stream.write('Hello ') - response.stream.write('World!') - assert response.response == ['Hello ', 'World!'] - assert response.get_data() == b'Hello World!' + response.stream.write("Hello ") + response.stream.write("World!") + assert response.response == ["Hello ", "World!"] + assert response.get_data() == b"Hello World!" def test_common_response_descriptors_mixin(): response = wrappers.Response() - response.mimetype = 'text/html' - assert response.mimetype == 'text/html' - assert response.content_type == 'text/html; charset=utf-8' - assert response.mimetype_params == {'charset': 'utf-8'} - response.mimetype_params['x-foo'] = 'yep' - del response.mimetype_params['charset'] - assert response.content_type == 'text/html; x-foo=yep' + response.mimetype = "text/html" + assert response.mimetype == "text/html" + assert response.content_type == "text/html; charset=utf-8" + assert response.mimetype_params == {"charset": "utf-8"} + response.mimetype_params["x-foo"] = "yep" + del response.mimetype_params["charset"] + assert response.content_type == "text/html; x-foo=yep" now = datetime.utcnow().replace(microsecond=0) assert response.content_length is None - response.content_length = '42' + response.content_length = "42" assert response.content_length == 42 - for attr in 'date', 'expires': + for attr in "date", "expires": assert getattr(response, attr) is None setattr(response, attr, now) assert getattr(response, attr) == now @@ -792,95 +943,97 @@ def test_common_response_descriptors_mixin(): assert response.retry_after == now assert not response.vary - response.vary.add('Cookie') - response.vary.add('Content-Language') - assert 'cookie' in response.vary - assert response.vary.to_header() == 'Cookie, Content-Language' - response.headers['Vary'] = 'Content-Encoding' - assert response.vary.as_set() == set(['content-encoding']) + response.vary.add("Cookie") + response.vary.add("Content-Language") + assert "cookie" in response.vary + assert response.vary.to_header() == "Cookie, Content-Language" + response.headers["Vary"] = "Content-Encoding" + assert response.vary.as_set() == {"content-encoding"} - response.allow.update(['GET', 'POST']) - assert response.headers['Allow'] == 'GET, POST' + response.allow.update(["GET", "POST"]) + assert response.headers["Allow"] == "GET, POST" - response.content_language.add('en-US') - response.content_language.add('fr') - assert response.headers['Content-Language'] == 'en-US, fr' + response.content_language.add("en-US") + response.content_language.add("fr") + assert response.headers["Content-Language"] == "en-US, fr" def test_common_request_descriptors_mixin(): request = wrappers.Request.from_values( - content_type='text/html; charset=utf-8', - content_length='23', + content_type="text/html; charset=utf-8", + content_length="23", headers={ - 'Referer': 'http://www.example.com/', - 'Date': 'Sat, 28 Feb 2009 19:04:35 GMT', - 'Max-Forwards': '10', - 'Pragma': 'no-cache', - 'Content-Encoding': 'gzip', - 'Content-MD5': '9a3bc6dbc47a70db25b84c6e5867a072' - } + "Referer": "http://www.example.com/", + "Date": "Sat, 28 Feb 2009 19:04:35 GMT", + "Max-Forwards": "10", + "Pragma": "no-cache", + "Content-Encoding": "gzip", + "Content-MD5": "9a3bc6dbc47a70db25b84c6e5867a072", + }, ) - assert request.content_type == 'text/html; charset=utf-8' - assert request.mimetype == 'text/html' - assert request.mimetype_params == {'charset': 'utf-8'} + assert request.content_type == "text/html; charset=utf-8" + assert request.mimetype == "text/html" + assert request.mimetype_params == {"charset": "utf-8"} assert request.content_length == 23 - assert request.referrer == 'http://www.example.com/' + assert request.referrer == "http://www.example.com/" assert request.date == datetime(2009, 2, 28, 19, 4, 35) assert request.max_forwards == 10 - assert 'no-cache' in request.pragma - assert request.content_encoding == 'gzip' - assert request.content_md5 == '9a3bc6dbc47a70db25b84c6e5867a072' + assert "no-cache" in request.pragma + assert request.content_encoding == "gzip" + assert request.content_md5 == "9a3bc6dbc47a70db25b84c6e5867a072" def test_request_mimetype_always_lowercase(): - request = wrappers.Request.from_values(content_type='APPLICATION/JSON') - assert request.mimetype == 'application/json' + request = wrappers.Request.from_values(content_type="APPLICATION/JSON") + assert request.mimetype == "application/json" def test_shallow_mode(): - request = wrappers.Request({'QUERY_STRING': 'foo=bar'}, shallow=True) - assert request.args['foo'] == 'bar' - pytest.raises(RuntimeError, lambda: request.form['foo']) + request = wrappers.Request({"QUERY_STRING": "foo=bar"}, shallow=True) + assert request.args["foo"] == "bar" + pytest.raises(RuntimeError, lambda: request.form["foo"]) def test_form_parsing_failed(): - data = b'--blah\r\n' + data = b"--blah\r\n" request = wrappers.Request.from_values( input_stream=BytesIO(data), content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST' + content_type="multipart/form-data; boundary=foo", + method="POST", ) assert not request.files assert not request.form # Bad Content-Type - data = b'test' + data = b"test" request = wrappers.Request.from_values( input_stream=BytesIO(data), content_length=len(data), - content_type=', ', - method='POST' + content_type=", ", + method="POST", ) assert not request.form def test_file_closing(): - data = (b'--foo\r\n' - b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' - b'Content-Type: text/plain; charset=utf-8\r\n\r\n' - b'file contents, just the contents\r\n' - b'--foo--') + data = ( + b"--foo\r\n" + b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' + b"Content-Type: text/plain; charset=utf-8\r\n\r\n" + b"file contents, just the contents\r\n" + b"--foo--" + ) req = wrappers.Request.from_values( input_stream=BytesIO(data), content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST' + content_type="multipart/form-data; boundary=foo", + method="POST", ) - foo = req.files['foo'] - assert foo.mimetype == 'text/plain' - assert foo.filename == 'foo.txt' + foo = req.files["foo"] + assert foo.mimetype == "text/plain" + assert foo.filename == "foo.txt" assert foo.closed is False req.close() @@ -888,29 +1041,31 @@ def test_file_closing(): def test_file_closing_with(): - data = (b'--foo\r\n' - b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' - b'Content-Type: text/plain; charset=utf-8\r\n\r\n' - b'file contents, just the contents\r\n' - b'--foo--') + data = ( + b"--foo\r\n" + b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n' + b"Content-Type: text/plain; charset=utf-8\r\n\r\n" + b"file contents, just the contents\r\n" + b"--foo--" + ) req = wrappers.Request.from_values( input_stream=BytesIO(data), content_length=len(data), - content_type='multipart/form-data; boundary=foo', - method='POST' + content_type="multipart/form-data; boundary=foo", + method="POST", ) with req: - foo = req.files['foo'] - assert foo.mimetype == 'text/plain' - assert foo.filename == 'foo.txt' + foo = req.files["foo"] + assert foo.mimetype == "text/plain" + assert foo.filename == "foo.txt" assert foo.closed is True def test_url_charset_reflection(): req = wrappers.Request.from_values() - req.charset = 'utf-7' - assert req.url_charset == 'utf-7' + req.charset = "utf-7" + assert req.url_charset == "utf-7" def test_response_streamed(): @@ -924,6 +1079,7 @@ def test_response_streamed(): def gen(): if 0: yield None + r = wrappers.Response(gen()) assert r.is_streamed @@ -934,58 +1090,61 @@ def test_response_iter_wrapping(): yield item.upper() def generator(): - yield 'foo' - yield 'bar' + yield "foo" + yield "bar" + req = wrappers.Request.from_values() resp = wrappers.Response(generator()) - del resp.headers['Content-Length'] + del resp.headers["Content-Length"] resp.response = uppercasing(resp.iter_encoded()) actual_resp = wrappers.Response.from_app(resp, req.environ, buffered=True) - assert actual_resp.get_data() == b'FOOBAR' + assert actual_resp.get_data() == b"FOOBAR" def test_response_freeze(): def generate(): yield "foo" yield "bar" + resp = wrappers.Response(generate()) resp.freeze() - assert resp.response == [b'foo', b'bar'] - assert resp.headers['content-length'] == '6' + assert resp.response == [b"foo", b"bar"] + assert resp.headers["content-length"] == "6" def test_response_content_length_uses_encode(): - r = wrappers.Response(u'你好') + r = wrappers.Response(u"你好") assert r.calculate_content_length() == 6 def test_other_method_payload(): - data = b'Hello World' - req = wrappers.Request.from_values(input_stream=BytesIO(data), - content_length=len(data), - content_type='text/plain', - method='WHAT_THE_FUCK') + data = b"Hello World" + req = wrappers.Request.from_values( + input_stream=BytesIO(data), + content_length=len(data), + content_type="text/plain", + method="WHAT_THE_FUCK", + ) assert req.get_data() == data assert isinstance(req.stream, LimitedStream) def test_urlfication(): resp = wrappers.Response() - resp.headers['Location'] = u'http://üser:pässword@☃.net/påth' - resp.headers['Content-Location'] = u'http://☃.net/' + resp.headers["Location"] = u"http://üser:pässword@☃.net/påth" + resp.headers["Content-Location"] = u"http://☃.net/" headers = resp.get_wsgi_headers(create_environ()) - assert headers['location'] == \ - 'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th' - assert headers['content-location'] == 'http://xn--n3h.net/' + assert headers["location"] == "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th" + assert headers["content-location"] == "http://xn--n3h.net/" def test_new_response_iterator_behavior(): req = wrappers.Request.from_values() - resp = wrappers.Response(u'Hello Wörld!') + resp = wrappers.Response(u"Hello Wörld!") def get_content_length(resp): headers = resp.get_wsgi_headers(req.environ) - return headers.get('content-length', type=int) + return headers.get("content-length", type=int) def generate_items(): yield "Hello " @@ -993,16 +1152,16 @@ def test_new_response_iterator_behavior(): # werkzeug encodes when set to `data` now, which happens # if a string is passed to the response object. - assert resp.response == [u'Hello Wörld!'.encode('utf-8')] - assert resp.get_data() == u'Hello Wörld!'.encode('utf-8') + assert resp.response == [u"Hello Wörld!".encode("utf-8")] + assert resp.get_data() == u"Hello Wörld!".encode("utf-8") assert get_content_length(resp) == 13 assert not resp.is_streamed assert resp.is_sequence # try the same for manual assignment - resp.set_data(u'Wörd') - assert resp.response == [u'Wörd'.encode('utf-8')] - assert resp.get_data() == u'Wörd'.encode('utf-8') + resp.set_data(u"Wörd") + assert resp.response == [u"Wörd".encode("utf-8")] + assert resp.get_data() == u"Wörd".encode("utf-8") assert get_content_length(resp) == 5 assert not resp.is_streamed assert resp.is_sequence @@ -1011,8 +1170,8 @@ def test_new_response_iterator_behavior(): resp.response = generate_items() assert resp.is_streamed assert not resp.is_sequence - assert resp.get_data() == u'Hello Wörld!'.encode('utf-8') - assert resp.response == [b'Hello ', u'Wörld!'.encode('utf-8')] + assert resp.get_data() == u"Hello Wörld!".encode("utf-8") + assert resp.response == [b"Hello ", u"Wörld!".encode("utf-8")] assert not resp.is_streamed assert resp.is_sequence @@ -1023,8 +1182,8 @@ def test_new_response_iterator_behavior(): assert not resp.is_sequence pytest.raises(RuntimeError, lambda: resp.get_data()) resp.make_sequence() - assert resp.get_data() == u'Hello Wörld!'.encode('utf-8') - assert resp.response == [b'Hello ', u'Wörld!'.encode('utf-8')] + assert resp.get_data() == u"Hello Wörld!".encode("utf-8") + assert resp.response == [b"Hello ", u"Wörld!".encode("utf-8")] assert not resp.is_streamed assert resp.is_sequence @@ -1033,25 +1192,25 @@ def test_new_response_iterator_behavior(): resp.implicit_sequence_conversion = val resp.response = ("foo", "bar") assert resp.is_sequence - resp.stream.write('baz') - assert resp.response == ['foo', 'bar', 'baz'] + resp.stream.write("baz") + assert resp.response == ["foo", "bar", "baz"] def test_form_data_ordering(): class MyRequest(wrappers.Request): parameter_storage_class = ImmutableOrderedMultiDict - req = MyRequest.from_values('/?foo=1&bar=0&foo=3') - assert list(req.args) == ['foo', 'bar'] + req = MyRequest.from_values("/?foo=1&bar=0&foo=3") + assert list(req.args) == ["foo", "bar"] assert list(req.args.items(multi=True)) == [ - ('foo', '1'), - ('bar', '0'), - ('foo', '3') + ("foo", "1"), + ("bar", "0"), + ("foo", "3"), ] assert isinstance(req.args, ImmutableOrderedMultiDict) assert isinstance(req.values, CombinedMultiDict) - assert req.values['foo'] == '1' - assert req.values.getlist('foo') == ['1', '3'] + assert req.values["foo"] == "1" + assert req.values.getlist("foo") == ["1", "3"] def test_storage_classes(): @@ -1059,22 +1218,19 @@ def test_storage_classes(): dict_storage_class = dict list_storage_class = list parameter_storage_class = dict - req = MyRequest.from_values('/?foo=baz', headers={ - 'Cookie': 'foo=bar' - }) + + req = MyRequest.from_values("/?foo=baz", headers={"Cookie": "foo=bar"}) assert type(req.cookies) is dict - assert req.cookies == {'foo': 'bar'} + assert req.cookies == {"foo": "bar"} assert type(req.access_route) is list assert type(req.args) is dict assert type(req.values) is CombinedMultiDict - assert req.values['foo'] == u'baz' + assert req.values["foo"] == u"baz" - req = wrappers.Request.from_values(headers={ - 'Cookie': 'foo=bar' - }) + req = wrappers.Request.from_values(headers={"Cookie": "foo=bar"}) assert type(req.cookies) is ImmutableTypeConversionDict - assert req.cookies == {'foo': 'bar'} + assert req.cookies == {"foo": "bar"} assert type(req.access_route) is ImmutableList MyRequest.list_storage_class = tuple @@ -1089,91 +1245,93 @@ def test_response_headers_passthrough(): def test_response_304_no_content_length(): - resp = wrappers.Response('Test', status=304) + resp = wrappers.Response("Test", status=304) env = create_environ() - assert 'content-length' not in resp.get_wsgi_headers(env) + assert "content-length" not in resp.get_wsgi_headers(env) def test_ranges(): # basic range stuff req = wrappers.Request.from_values() assert req.range is None - req = wrappers.Request.from_values(headers={'Range': 'bytes=0-499'}) + req = wrappers.Request.from_values(headers={"Range": "bytes=0-499"}) assert req.range.ranges == [(0, 500)] resp = wrappers.Response() resp.content_range = req.range.make_content_range(1000) - assert resp.content_range.units == 'bytes' + assert resp.content_range.units == "bytes" assert resp.content_range.start == 0 assert resp.content_range.stop == 500 assert resp.content_range.length == 1000 - assert resp.headers['Content-Range'] == 'bytes 0-499/1000' + assert resp.headers["Content-Range"] == "bytes 0-499/1000" resp.content_range.unset() - assert 'Content-Range' not in resp.headers + assert "Content-Range" not in resp.headers - resp.headers['Content-Range'] = 'bytes 0-499/1000' - assert resp.content_range.units == 'bytes' + resp.headers["Content-Range"] = "bytes 0-499/1000" + assert resp.content_range.units == "bytes" assert resp.content_range.start == 0 assert resp.content_range.stop == 500 assert resp.content_range.length == 1000 def test_auto_content_length(): - resp = wrappers.Response('Hello World!') + resp = wrappers.Response("Hello World!") assert resp.content_length == 12 - resp = wrappers.Response(['Hello World!']) + resp = wrappers.Response(["Hello World!"]) assert resp.content_length is None - assert resp.get_wsgi_headers({})['Content-Length'] == '12' + assert resp.get_wsgi_headers({})["Content-Length"] == "12" def test_stream_content_length(): resp = wrappers.Response() - resp.stream.writelines(['foo', 'bar', 'baz']) - assert resp.get_wsgi_headers({})['Content-Length'] == '9' + resp.stream.writelines(["foo", "bar", "baz"]) + assert resp.get_wsgi_headers({})["Content-Length"] == "9" resp = wrappers.Response() - resp.make_conditional({'REQUEST_METHOD': 'GET'}) - resp.stream.writelines(['foo', 'bar', 'baz']) - assert resp.get_wsgi_headers({})['Content-Length'] == '9' + resp.make_conditional({"REQUEST_METHOD": "GET"}) + resp.stream.writelines(["foo", "bar", "baz"]) + assert resp.get_wsgi_headers({})["Content-Length"] == "9" - resp = wrappers.Response('foo') - resp.stream.writelines(['bar', 'baz']) - assert resp.get_wsgi_headers({})['Content-Length'] == '9' + resp = wrappers.Response("foo") + resp.stream.writelines(["bar", "baz"]) + assert resp.get_wsgi_headers({})["Content-Length"] == "9" def test_disabled_auto_content_length(): class MyResponse(wrappers.Response): automatically_set_content_length = False - resp = MyResponse('Hello World!') + + resp = MyResponse("Hello World!") assert resp.content_length is None - resp = MyResponse(['Hello World!']) + resp = MyResponse(["Hello World!"]) assert resp.content_length is None - assert 'Content-Length' not in resp.get_wsgi_headers({}) + assert "Content-Length" not in resp.get_wsgi_headers({}) resp = MyResponse() - resp.make_conditional({ - 'REQUEST_METHOD': 'GET' - }) + resp.make_conditional({"REQUEST_METHOD": "GET"}) assert resp.content_length is None - assert 'Content-Length' not in resp.get_wsgi_headers({}) - - -@pytest.mark.parametrize(('auto', 'location', 'expect'), ( - (False, '/test', '/test'), - (True, '/test', 'http://localhost/test'), - (True, 'test', 'http://localhost/a/b/test'), - (True, './test', 'http://localhost/a/b/test'), - (True, '../test', 'http://localhost/a/test'), -)) + assert "Content-Length" not in resp.get_wsgi_headers({}) + + +@pytest.mark.parametrize( + ("auto", "location", "expect"), + ( + (False, "/test", "/test"), + (True, "/test", "http://localhost/test"), + (True, "test", "http://localhost/a/b/test"), + (True, "./test", "http://localhost/a/b/test"), + (True, "../test", "http://localhost/a/test"), + ), +) def test_location_header_autocorrect(monkeypatch, auto, location, expect): - monkeypatch.setattr(wrappers.Response, 'autocorrect_location_header', auto) - env = create_environ('/a/b/c') - resp = wrappers.Response('Hello World!') - resp.headers['Location'] = location - assert resp.get_wsgi_headers(env)['Location'] == expect + monkeypatch.setattr(wrappers.Response, "autocorrect_location_header", auto) + env = create_environ("/a/b/c") + resp = wrappers.Response("Hello World!") + resp.headers["Location"] = location + assert resp.get_wsgi_headers(env)["Location"] == expect def test_204_and_1XX_response_has_no_content_length(): @@ -1181,39 +1339,39 @@ def test_204_and_1XX_response_has_no_content_length(): assert response.content_length is None headers = response.get_wsgi_headers(create_environ()) - assert 'Content-Length' not in headers + assert "Content-Length" not in headers response = wrappers.Response(status=100) assert response.content_length is None headers = response.get_wsgi_headers(create_environ()) - assert 'Content-Length' not in headers + assert "Content-Length" not in headers def test_malformed_204_response_has_no_content_length(): # flask-restful can generate a malformed response when doing `return '', 204` response = wrappers.Response(status=204) - response.set_data(b'test') + response.set_data(b"test") assert response.content_length == 4 env = create_environ() app_iter, status, headers = response.get_wsgi_response(env) - assert status == '204 NO CONTENT' - assert 'Content-Length' not in headers - assert b''.join(app_iter) == b'' # ensure data will not be sent + assert status == "204 NO CONTENT" + assert "Content-Length" not in headers + assert b"".join(app_iter) == b"" # ensure data will not be sent def test_modified_url_encoding(): class ModifiedRequest(wrappers.Request): - url_charset = 'euc-kr' + url_charset = "euc-kr" - req = ModifiedRequest.from_values(u'/?foo=정상처리'.encode('euc-kr')) - strict_eq(req.args['foo'], u'정상처리') + req = ModifiedRequest.from_values(u"/?foo=정상처리".encode("euc-kr")) + strict_eq(req.args["foo"], u"정상처리") def test_request_method_case_sensitivity(): - req = wrappers.Request({'REQUEST_METHOD': 'get'}) - assert req.method == 'GET' + req = wrappers.Request({"REQUEST_METHOD": "get"}) + assert req.method == "GET" def test_is_xhr_warning(): @@ -1225,7 +1383,7 @@ def test_is_xhr_warning(): def test_write_length(): response = wrappers.Response() - length = response.stream.write(b'bar') + length = response.stream.write(b"bar") assert length == 3 @@ -1233,13 +1391,13 @@ def test_stream_zip(): import zipfile response = wrappers.Response() - with contextlib.closing(zipfile.ZipFile(response.stream, mode='w')) as z: + with contextlib.closing(zipfile.ZipFile(response.stream, mode="w")) as z: z.writestr("foo", b"bar") buffer = BytesIO(response.get_data()) - with contextlib.closing(zipfile.ZipFile(buffer, mode='r')) as z: - assert z.namelist() == ['foo'] - assert z.read('foo') == b'bar' + with contextlib.closing(zipfile.ZipFile(buffer, mode="r")) as z: + assert z.namelist() == ["foo"] + assert z.read("foo") == b"bar" class TestSetCookie(object): @@ -1247,49 +1405,103 @@ class TestSetCookie(object): def test_secure(self): response = wrappers.BaseResponse() - response.set_cookie('foo', value='bar', max_age=60, expires=0, - path='/blub', domain='example.org', secure=True, - samesite=None) - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, ' - '01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; Path=/blub') - ]) + response.set_cookie( + "foo", + value="bar", + max_age=60, + expires=0, + path="/blub", + domain="example.org", + secure=True, + samesite=None, + ) + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=bar; Domain=example.org; Expires=Thu, " + "01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; Path=/blub", + ), + ], + ) def test_httponly(self): response = wrappers.BaseResponse() - response.set_cookie('foo', value='bar', max_age=60, expires=0, - path='/blub', domain='example.org', secure=False, - httponly=True, samesite=None) - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, ' - '01-Jan-1970 00:00:00 GMT; Max-Age=60; HttpOnly; Path=/blub') - ]) + response.set_cookie( + "foo", + value="bar", + max_age=60, + expires=0, + path="/blub", + domain="example.org", + secure=False, + httponly=True, + samesite=None, + ) + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=bar; Domain=example.org; Expires=Thu, " + "01-Jan-1970 00:00:00 GMT; Max-Age=60; HttpOnly; Path=/blub", + ), + ], + ) def test_secure_and_httponly(self): response = wrappers.BaseResponse() - response.set_cookie('foo', value='bar', max_age=60, expires=0, - path='/blub', domain='example.org', secure=True, - httponly=True, samesite=None) - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, ' - '01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; HttpOnly; ' - 'Path=/blub') - ]) + response.set_cookie( + "foo", + value="bar", + max_age=60, + expires=0, + path="/blub", + domain="example.org", + secure=True, + httponly=True, + samesite=None, + ) + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=bar; Domain=example.org; Expires=Thu, " + "01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; HttpOnly; " + "Path=/blub", + ), + ], + ) def test_samesite(self): response = wrappers.BaseResponse() - response.set_cookie('foo', value='bar', max_age=60, expires=0, - path='/blub', domain='example.org', secure=False, - samesite='strict') - strict_eq(response.headers.to_wsgi_list(), [ - ('Content-Type', 'text/plain; charset=utf-8'), - ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, ' - '01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; ' - 'SameSite=Strict') - ]) + response.set_cookie( + "foo", + value="bar", + max_age=60, + expires=0, + path="/blub", + domain="example.org", + secure=False, + samesite="strict", + ) + strict_eq( + response.headers.to_wsgi_list(), + [ + ("Content-Type", "text/plain; charset=utf-8"), + ( + "Set-Cookie", + "foo=bar; Domain=example.org; Expires=Thu, " + "01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; " + "SameSite=Strict", + ), + ], + ) class TestJSONMixin(object): @@ -1308,8 +1520,7 @@ class TestJSONMixin(object): def test_response(self): value = {u"ä": "b"} response = self.Response( - response=json.dumps(value), - content_type="application/json", + response=json.dumps(value), content_type="application/json" ) assert response.json == value @@ -1321,8 +1532,7 @@ class TestJSONMixin(object): def test_silent(self): request = self.Request.from_values( - data=b'{"a":}', - content_type="application/json", + data=b'{"a":}', content_type="application/json" ) assert request.get_json(silent=True) is None diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 30c212a7..f99aa32c 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -14,208 +14,212 @@ import os import pytest -from tests import strict_eq +from . import strict_eq from werkzeug import wsgi -from werkzeug._compat import BytesIO, NativeStringIO, StringIO -from werkzeug.exceptions import BadRequest, ClientDisconnected -from werkzeug.test import Client, create_environ, run_wsgi_app +from werkzeug._compat import BytesIO +from werkzeug._compat import NativeStringIO +from werkzeug._compat import StringIO +from werkzeug.exceptions import BadRequest +from werkzeug.exceptions import ClientDisconnected +from werkzeug.test import Client +from werkzeug.test import create_environ +from werkzeug.test import run_wsgi_app from werkzeug.wrappers import BaseResponse -from werkzeug.wsgi import _RangeWrapper, ClosingIterator, wrap_file - - -@pytest.mark.parametrize(('environ', 'expect'), ( - pytest.param({ - 'HTTP_HOST': 'spam', - }, 'spam', id='host'), - pytest.param({ - 'HTTP_HOST': 'spam:80', - }, 'spam', id='host, strip http port'), - pytest.param({ - 'wsgi.url_scheme': 'https', - 'HTTP_HOST': 'spam:443', - }, 'spam', id='host, strip https port'), - pytest.param({ - 'HTTP_HOST': 'spam:8080', - }, 'spam:8080', id='host, custom port'), - pytest.param({ - 'HTTP_HOST': 'spam', - 'SERVER_NAME': 'eggs', - 'SERVER_PORT': '80', - }, 'spam', id='prefer host'), - pytest.param({ - 'SERVER_NAME': 'eggs', - 'SERVER_PORT': '80' - }, 'eggs', id='name, ignore http port'), - pytest.param({ - 'wsgi.url_scheme': 'https', - 'SERVER_NAME': 'eggs', - 'SERVER_PORT': '443' - }, 'eggs', id='name, ignore https port'), - pytest.param({ - 'SERVER_NAME': 'eggs', - 'SERVER_PORT': '8080' - }, 'eggs:8080', id='name, custom port'), - pytest.param({ - 'HTTP_HOST': 'ham', - 'HTTP_X_FORWARDED_HOST': 'eggs' - }, 'ham', id='ignore x-forwarded-host'), -)) +from werkzeug.wsgi import _RangeWrapper +from werkzeug.wsgi import ClosingIterator +from werkzeug.wsgi import wrap_file + + +@pytest.mark.parametrize( + ("environ", "expect"), + ( + pytest.param({"HTTP_HOST": "spam"}, "spam", id="host"), + pytest.param({"HTTP_HOST": "spam:80"}, "spam", id="host, strip http port"), + pytest.param( + {"wsgi.url_scheme": "https", "HTTP_HOST": "spam:443"}, + "spam", + id="host, strip https port", + ), + pytest.param({"HTTP_HOST": "spam:8080"}, "spam:8080", id="host, custom port"), + pytest.param( + {"HTTP_HOST": "spam", "SERVER_NAME": "eggs", "SERVER_PORT": "80"}, + "spam", + id="prefer host", + ), + pytest.param( + {"SERVER_NAME": "eggs", "SERVER_PORT": "80"}, + "eggs", + id="name, ignore http port", + ), + pytest.param( + {"wsgi.url_scheme": "https", "SERVER_NAME": "eggs", "SERVER_PORT": "443"}, + "eggs", + id="name, ignore https port", + ), + pytest.param( + {"SERVER_NAME": "eggs", "SERVER_PORT": "8080"}, + "eggs:8080", + id="name, custom port", + ), + pytest.param( + {"HTTP_HOST": "ham", "HTTP_X_FORWARDED_HOST": "eggs"}, + "ham", + id="ignore x-forwarded-host", + ), + ), +) def test_get_host(environ, expect): - environ.setdefault('wsgi.url_scheme', 'http') + environ.setdefault("wsgi.url_scheme", "http") assert wsgi.get_host(environ) == expect def test_get_host_validate_trusted_hosts(): - env = {'SERVER_NAME': 'example.org', 'SERVER_PORT': '80', - 'wsgi.url_scheme': 'http'} - assert wsgi.get_host(env, trusted_hosts=['.example.org']) == 'example.org' - pytest.raises(BadRequest, wsgi.get_host, env, - trusted_hosts=['example.com']) - env['SERVER_PORT'] = '8080' - assert wsgi.get_host(env, trusted_hosts=['.example.org:8080']) == 'example.org:8080' - pytest.raises(BadRequest, wsgi.get_host, env, - trusted_hosts=['.example.com']) - env = {'HTTP_HOST': 'example.org', 'wsgi.url_scheme': 'http'} - assert wsgi.get_host(env, trusted_hosts=['.example.org']) == 'example.org' - pytest.raises(BadRequest, wsgi.get_host, env, - trusted_hosts=['example.com']) + env = {"SERVER_NAME": "example.org", "SERVER_PORT": "80", "wsgi.url_scheme": "http"} + assert wsgi.get_host(env, trusted_hosts=[".example.org"]) == "example.org" + pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=["example.com"]) + env["SERVER_PORT"] = "8080" + assert wsgi.get_host(env, trusted_hosts=[".example.org:8080"]) == "example.org:8080" + pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=[".example.com"]) + env = {"HTTP_HOST": "example.org", "wsgi.url_scheme": "http"} + assert wsgi.get_host(env, trusted_hosts=[".example.org"]) == "example.org" + pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=["example.com"]) def test_responder(): def foo(environ, start_response): - return BaseResponse(b'Test') + return BaseResponse(b"Test") + client = Client(wsgi.responder(foo), BaseResponse) - response = client.get('/') + response = client.get("/") assert response.status_code == 200 - assert response.data == b'Test' + assert response.data == b"Test" def test_pop_path_info(): - original_env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b///c'} + original_env = {"SCRIPT_NAME": "/foo", "PATH_INFO": "/a/b///c"} # regular path info popping def assert_tuple(script_name, path_info): - assert env.get('SCRIPT_NAME') == script_name - assert env.get('PATH_INFO') == path_info + assert env.get("SCRIPT_NAME") == script_name + assert env.get("PATH_INFO") == path_info + env = original_env.copy() - pop = lambda: wsgi.pop_path_info(env) - - assert_tuple('/foo', '/a/b///c') - assert pop() == 'a' - assert_tuple('/foo/a', '/b///c') - assert pop() == 'b' - assert_tuple('/foo/a/b', '///c') - assert pop() == 'c' - assert_tuple('/foo/a/b///c', '') + + def pop(): + return wsgi.pop_path_info(env) + + assert_tuple("/foo", "/a/b///c") + assert pop() == "a" + assert_tuple("/foo/a", "/b///c") + assert pop() == "b" + assert_tuple("/foo/a/b", "///c") + assert pop() == "c" + assert_tuple("/foo/a/b///c", "") assert pop() is None def test_peek_path_info(): - env = { - 'SCRIPT_NAME': '/foo', - 'PATH_INFO': '/aaa/b///c' - } + env = {"SCRIPT_NAME": "/foo", "PATH_INFO": "/aaa/b///c"} - assert wsgi.peek_path_info(env) == 'aaa' - assert wsgi.peek_path_info(env) == 'aaa' - assert wsgi.peek_path_info(env, charset=None) == b'aaa' - assert wsgi.peek_path_info(env, charset=None) == b'aaa' + assert wsgi.peek_path_info(env) == "aaa" + assert wsgi.peek_path_info(env) == "aaa" + assert wsgi.peek_path_info(env, charset=None) == b"aaa" + assert wsgi.peek_path_info(env, charset=None) == b"aaa" def test_path_info_and_script_name_fetching(): - env = create_environ(u'/\N{SNOWMAN}', u'http://example.com/\N{COMET}/') - assert wsgi.get_path_info(env) == u'/\N{SNOWMAN}' - assert wsgi.get_path_info(env, charset=None) == u'/\N{SNOWMAN}'.encode('utf-8') - assert wsgi.get_script_name(env) == u'/\N{COMET}' - assert wsgi.get_script_name(env, charset=None) == u'/\N{COMET}'.encode('utf-8') + env = create_environ(u"/\N{SNOWMAN}", u"http://example.com/\N{COMET}/") + assert wsgi.get_path_info(env) == u"/\N{SNOWMAN}" + assert wsgi.get_path_info(env, charset=None) == u"/\N{SNOWMAN}".encode("utf-8") + assert wsgi.get_script_name(env) == u"/\N{COMET}" + assert wsgi.get_script_name(env, charset=None) == u"/\N{COMET}".encode("utf-8") def test_query_string_fetching(): - env = create_environ(u'/?\N{SNOWMAN}=\N{COMET}') + env = create_environ(u"/?\N{SNOWMAN}=\N{COMET}") qs = wsgi.get_query_string(env) - strict_eq(qs, '%E2%98%83=%E2%98%84') + strict_eq(qs, "%E2%98%83=%E2%98%84") def test_limited_stream(): class RaisingLimitedStream(wsgi.LimitedStream): - def on_exhausted(self): - raise BadRequest('input stream exhausted') + raise BadRequest("input stream exhausted") - io = BytesIO(b'123456') + io = BytesIO(b"123456") stream = RaisingLimitedStream(io, 3) - strict_eq(stream.read(), b'123') + strict_eq(stream.read(), b"123") pytest.raises(BadRequest, stream.read) - io = BytesIO(b'123456') + io = BytesIO(b"123456") stream = RaisingLimitedStream(io, 3) strict_eq(stream.tell(), 0) - strict_eq(stream.read(1), b'1') + strict_eq(stream.read(1), b"1") strict_eq(stream.tell(), 1) - strict_eq(stream.read(1), b'2') + strict_eq(stream.read(1), b"2") strict_eq(stream.tell(), 2) - strict_eq(stream.read(1), b'3') + strict_eq(stream.read(1), b"3") strict_eq(stream.tell(), 3) pytest.raises(BadRequest, stream.read) - io = BytesIO(b'123456\nabcdefg') + io = BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io, 9) - strict_eq(stream.readline(), b'123456\n') - strict_eq(stream.readline(), b'ab') + strict_eq(stream.readline(), b"123456\n") + strict_eq(stream.readline(), b"ab") - io = BytesIO(b'123456\nabcdefg') + io = BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io, 9) - strict_eq(stream.readlines(), [b'123456\n', b'ab']) + strict_eq(stream.readlines(), [b"123456\n", b"ab"]) - io = BytesIO(b'123456\nabcdefg') + io = BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io, 9) - strict_eq(stream.readlines(2), [b'12']) - strict_eq(stream.readlines(2), [b'34']) - strict_eq(stream.readlines(), [b'56\n', b'ab']) + strict_eq(stream.readlines(2), [b"12"]) + strict_eq(stream.readlines(2), [b"34"]) + strict_eq(stream.readlines(), [b"56\n", b"ab"]) - io = BytesIO(b'123456\nabcdefg') + io = BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io, 9) - strict_eq(stream.readline(100), b'123456\n') + strict_eq(stream.readline(100), b"123456\n") - io = BytesIO(b'123456\nabcdefg') + io = BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io, 9) - strict_eq(stream.readlines(100), [b'123456\n', b'ab']) + strict_eq(stream.readlines(100), [b"123456\n", b"ab"]) - io = BytesIO(b'123456') + io = BytesIO(b"123456") stream = wsgi.LimitedStream(io, 3) - strict_eq(stream.read(1), b'1') - strict_eq(stream.read(1), b'2') - strict_eq(stream.read(), b'3') - strict_eq(stream.read(), b'') + strict_eq(stream.read(1), b"1") + strict_eq(stream.read(1), b"2") + strict_eq(stream.read(), b"3") + strict_eq(stream.read(), b"") - io = BytesIO(b'123456') + io = BytesIO(b"123456") stream = wsgi.LimitedStream(io, 3) - strict_eq(stream.read(-1), b'123') + strict_eq(stream.read(-1), b"123") - io = BytesIO(b'123456') + io = BytesIO(b"123456") stream = wsgi.LimitedStream(io, 0) - strict_eq(stream.read(-1), b'') + strict_eq(stream.read(-1), b"") - io = StringIO(u'123456') + io = StringIO(u"123456") stream = wsgi.LimitedStream(io, 0) - strict_eq(stream.read(-1), u'') + strict_eq(stream.read(-1), u"") - io = StringIO(u'123\n456\n') + io = StringIO(u"123\n456\n") stream = wsgi.LimitedStream(io, 8) - strict_eq(list(stream), [u'123\n', u'456\n']) + strict_eq(list(stream), [u"123\n", u"456\n"]) def test_limited_stream_json_load(): stream = wsgi.LimitedStream(BytesIO(b'{"hello": "test"}'), 17) # flask.json adapts bytes to text with TextIOWrapper # this expects stream.readable() to exist and return true - stream = io.TextIOWrapper(io.BufferedReader(stream), 'UTF-8') + stream = io.TextIOWrapper(io.BufferedReader(stream), "UTF-8") data = json.load(stream) - assert data == {'hello': 'test'} + assert data == {"hello": "test"} def test_limited_stream_disconnection(): - io = BytesIO(b'A bit of content') + io = BytesIO(b"A bit of content") # disconnect detection on out of bytes stream = wsgi.LimitedStream(io, 255) @@ -223,7 +227,7 @@ def test_limited_stream_disconnection(): stream.read() # disconnect detection because file close - io = BytesIO(b'x' * 255) + io = BytesIO(b"x" * 255) io.close() stream = wsgi.LimitedStream(io, 255) with pytest.raises(ClientDisconnected): @@ -231,46 +235,58 @@ def test_limited_stream_disconnection(): def test_path_info_extraction(): - x = wsgi.extract_path_info('http://example.com/app', '/app/hello') - assert x == u'/hello' - x = wsgi.extract_path_info('http://example.com/app', - 'https://example.com/app/hello') - assert x == u'/hello' - x = wsgi.extract_path_info('http://example.com/app/', - 'https://example.com/app/hello') - assert x == u'/hello' - x = wsgi.extract_path_info('http://example.com/app/', - 'https://example.com/app') - assert x == u'/' - x = wsgi.extract_path_info(u'http://☃.net/', u'/fööbär') - assert x == u'/fööbär' - x = wsgi.extract_path_info(u'http://☃.net/x', u'http://☃.net/x/fööbär') - assert x == u'/fööbär' - - env = create_environ(u'/fööbär', u'http://☃.net/x/') - x = wsgi.extract_path_info(env, u'http://☃.net/x/fööbär') - assert x == u'/fööbär' - - x = wsgi.extract_path_info('http://example.com/app/', - 'https://example.com/a/hello') + x = wsgi.extract_path_info("http://example.com/app", "/app/hello") + assert x == u"/hello" + x = wsgi.extract_path_info( + "http://example.com/app", "https://example.com/app/hello" + ) + assert x == u"/hello" + x = wsgi.extract_path_info( + "http://example.com/app/", "https://example.com/app/hello" + ) + assert x == u"/hello" + x = wsgi.extract_path_info("http://example.com/app/", "https://example.com/app") + assert x == u"/" + x = wsgi.extract_path_info(u"http://☃.net/", u"/fööbär") + assert x == u"/fööbär" + x = wsgi.extract_path_info(u"http://☃.net/x", u"http://☃.net/x/fööbär") + assert x == u"/fööbär" + + env = create_environ(u"/fööbär", u"http://☃.net/x/") + x = wsgi.extract_path_info(env, u"http://☃.net/x/fööbär") + assert x == u"/fööbär" + + x = wsgi.extract_path_info("http://example.com/app/", "https://example.com/a/hello") assert x is None - x = wsgi.extract_path_info('http://example.com/app/', - 'https://example.com/app/hello', - collapse_http_schemes=False) + x = wsgi.extract_path_info( + "http://example.com/app/", + "https://example.com/app/hello", + collapse_http_schemes=False, + ) assert x is None def test_get_host_fallback(): - assert wsgi.get_host({ - 'SERVER_NAME': 'foobar.example.com', - 'wsgi.url_scheme': 'http', - 'SERVER_PORT': '80' - }) == 'foobar.example.com' - assert wsgi.get_host({ - 'SERVER_NAME': 'foobar.example.com', - 'wsgi.url_scheme': 'http', - 'SERVER_PORT': '81' - }) == 'foobar.example.com:81' + assert ( + wsgi.get_host( + { + "SERVER_NAME": "foobar.example.com", + "wsgi.url_scheme": "http", + "SERVER_PORT": "80", + } + ) + == "foobar.example.com" + ) + assert ( + wsgi.get_host( + { + "SERVER_NAME": "foobar.example.com", + "wsgi.url_scheme": "http", + "SERVER_PORT": "81", + } + ) + == "foobar.example.com:81" + ) def test_get_current_url_unicode(): @@ -289,150 +305,172 @@ def test_get_current_url_invalid_utf8(): def test_multi_part_line_breaks(): - data = 'abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK' + data = "abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" test_stream = NativeStringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), - buffer_size=16)) - assert lines == ['abcdef\r\n', 'ghijkl\r\n', 'mnopqrstuvwxyz\r\n', - 'ABCDEFGHIJK'] + lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) + assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - data = 'abc\r\nThis line is broken by the buffer length.' \ - '\r\nFoo bar baz' + data = "abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" test_stream = NativeStringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), - buffer_size=24)) - assert lines == ['abc\r\n', 'This line is broken by the buffer ' - 'length.\r\n', 'Foo bar baz'] + lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) + assert lines == [ + "abc\r\n", + "This line is broken by the buffer length.\r\n", + "Foo bar baz", + ] def test_multi_part_line_breaks_bytes(): - data = b'abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK' + data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" test_stream = BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), - buffer_size=16)) - assert lines == [b'abcdef\r\n', b'ghijkl\r\n', b'mnopqrstuvwxyz\r\n', - b'ABCDEFGHIJK'] - - data = b'abc\r\nThis line is broken by the buffer length.' \ - b'\r\nFoo bar baz' + lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) + assert lines == [ + b"abcdef\r\n", + b"ghijkl\r\n", + b"mnopqrstuvwxyz\r\n", + b"ABCDEFGHIJK", + ] + + data = b"abc\r\nThis line is broken by the buffer length." b"\r\nFoo bar baz" test_stream = BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), - buffer_size=24)) - assert lines == [b'abc\r\n', b'This line is broken by the buffer ' - b'length.\r\n', b'Foo bar baz'] + lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) + assert lines == [ + b"abc\r\n", + b"This line is broken by the buffer " b"length.\r\n", + b"Foo bar baz", + ] def test_multi_part_line_breaks_problematic(): - data = 'abc\rdef\r\nghi' - for x in range(1, 10): + data = "abc\rdef\r\nghi" + for _ in range(1, 10): test_stream = NativeStringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), - buffer_size=4)) - assert lines == ['abc\r', 'def\r\n', 'ghi'] + lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4)) + assert lines == ["abc\r", "def\r\n", "ghi"] def test_iter_functions_support_iterators(): - data = ['abcdef\r\nghi', 'jkl\r\nmnopqrstuvwxyz\r', '\nABCDEFGHIJK'] + data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"] lines = list(wsgi.make_line_iter(data)) - assert lines == ['abcdef\r\n', 'ghijkl\r\n', 'mnopqrstuvwxyz\r\n', - 'ABCDEFGHIJK'] + assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] def test_make_chunk_iter(): - data = [u'abcdefXghi', u'jklXmnopqrstuvwxyzX', u'ABCDEFGHIJK'] - rv = list(wsgi.make_chunk_iter(data, 'X')) - assert rv == [u'abcdef', u'ghijkl', u'mnopqrstuvwxyz', u'ABCDEFGHIJK'] + data = [u"abcdefXghi", u"jklXmnopqrstuvwxyzX", u"ABCDEFGHIJK"] + rv = list(wsgi.make_chunk_iter(data, "X")) + assert rv == [u"abcdef", u"ghijkl", u"mnopqrstuvwxyz", u"ABCDEFGHIJK"] - data = u'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK' + data = u"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" test_stream = StringIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data), - buffer_size=4)) - assert rv == [u'abcdef', u'ghijkl', u'mnopqrstuvwxyz', u'ABCDEFGHIJK'] + rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) + assert rv == [u"abcdef", u"ghijkl", u"mnopqrstuvwxyz", u"ABCDEFGHIJK"] def test_make_chunk_iter_bytes(): - data = [b'abcdefXghi', b'jklXmnopqrstuvwxyzX', b'ABCDEFGHIJK'] - rv = list(wsgi.make_chunk_iter(data, 'X')) - assert rv == [b'abcdef', b'ghijkl', b'mnopqrstuvwxyz', b'ABCDEFGHIJK'] + data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"] + rv = list(wsgi.make_chunk_iter(data, "X")) + assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - data = b'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK' + data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" test_stream = BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data), - buffer_size=4)) - assert rv == [b'abcdef', b'ghijkl', b'mnopqrstuvwxyz', b'ABCDEFGHIJK'] + rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) + assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - data = b'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK' + data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" test_stream = BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data), - buffer_size=4, cap_at_buffer=True)) - assert rv == [b'abcd', b'ef', b'ghij', b'kl', b'mnop', b'qrst', b'uvwx', - b'yz', b'ABCD', b'EFGH', b'IJK'] + rv = list( + wsgi.make_chunk_iter( + test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True + ) + ) + assert rv == [ + b"abcd", + b"ef", + b"ghij", + b"kl", + b"mnop", + b"qrst", + b"uvwx", + b"yz", + b"ABCD", + b"EFGH", + b"IJK", + ] def test_lines_longer_buffer_size(): - data = '1234567890\n1234567890\n' + data = "1234567890\n1234567890\n" for bufsize in range(1, 15): - lines = list(wsgi.make_line_iter(NativeStringIO(data), limit=len(data), - buffer_size=4)) - assert lines == ['1234567890\n', '1234567890\n'] + lines = list( + wsgi.make_line_iter( + NativeStringIO(data), limit=len(data), buffer_size=bufsize + ) + ) + assert lines == ["1234567890\n", "1234567890\n"] def test_lines_longer_buffer_size_cap(): - data = '1234567890\n1234567890\n' + data = "1234567890\n1234567890\n" for bufsize in range(1, 15): - lines = list(wsgi.make_line_iter(NativeStringIO(data), limit=len(data), - buffer_size=4, cap_at_buffer=True)) - assert lines == ['1234', '5678', '90\n', '1234', '5678', '90\n'] + lines = list( + wsgi.make_line_iter( + NativeStringIO(data), + limit=len(data), + buffer_size=bufsize, + cap_at_buffer=True, + ) + ) + assert len(lines[0]) == bufsize or lines[0].endswith("\n") def test_range_wrapper(): - response = BaseResponse(b'Hello World') + response = BaseResponse(b"Hello World") range_wrapper = _RangeWrapper(response.response, 6, 4) - assert next(range_wrapper) == b'Worl' + assert next(range_wrapper) == b"Worl" - response = BaseResponse(b'Hello World') + response = BaseResponse(b"Hello World") range_wrapper = _RangeWrapper(response.response, 1, 0) with pytest.raises(StopIteration): next(range_wrapper) - response = BaseResponse(b'Hello World') + response = BaseResponse(b"Hello World") range_wrapper = _RangeWrapper(response.response, 6, 100) - assert next(range_wrapper) == b'World' + assert next(range_wrapper) == b"World" - response = BaseResponse((x for x in (b'He', b'll', b'o ', b'Wo', b'rl', b'd'))) + response = BaseResponse((x for x in (b"He", b"ll", b"o ", b"Wo", b"rl", b"d"))) range_wrapper = _RangeWrapper(response.response, 6, 4) assert not range_wrapper.seekable - assert next(range_wrapper) == b'Wo' - assert next(range_wrapper) == b'rl' + assert next(range_wrapper) == b"Wo" + assert next(range_wrapper) == b"rl" - response = BaseResponse((x for x in (b'He', b'll', b'o W', b'o', b'rld'))) + response = BaseResponse((x for x in (b"He", b"ll", b"o W", b"o", b"rld"))) range_wrapper = _RangeWrapper(response.response, 6, 4) - assert next(range_wrapper) == b'W' - assert next(range_wrapper) == b'o' - assert next(range_wrapper) == b'rl' + assert next(range_wrapper) == b"W" + assert next(range_wrapper) == b"o" + assert next(range_wrapper) == b"rl" with pytest.raises(StopIteration): next(range_wrapper) - response = BaseResponse((x for x in (b'Hello', b' World'))) + response = BaseResponse((x for x in (b"Hello", b" World"))) range_wrapper = _RangeWrapper(response.response, 1, 1) - assert next(range_wrapper) == b'e' + assert next(range_wrapper) == b"e" with pytest.raises(StopIteration): next(range_wrapper) - resources = os.path.join(os.path.dirname(__file__), 'res') + resources = os.path.join(os.path.dirname(__file__), "res") env = create_environ() - with open(os.path.join(resources, 'test.txt'), 'rb') as f: + with open(os.path.join(resources, "test.txt"), "rb") as f: response = BaseResponse(wrap_file(env, f)) range_wrapper = _RangeWrapper(response.response, 1, 2) assert range_wrapper.seekable - assert next(range_wrapper) == b'OU' + assert next(range_wrapper) == b"OU" with pytest.raises(StopIteration): next(range_wrapper) - with open(os.path.join(resources, 'test.txt'), 'rb') as f: + with open(os.path.join(resources, "test.txt"), "rb") as f: response = BaseResponse(wrap_file(env, f)) range_wrapper = _RangeWrapper(response.response, 2) - assert next(range_wrapper) == b'UND\n' + assert next(range_wrapper) == b"UND\n" with pytest.raises(StopIteration): next(range_wrapper) @@ -450,8 +488,8 @@ def test_closing_iterator(): # iterator. This ensures that ClosingIterator calls close on # the iterable (the object), not the iterator. def __iter__(self): - self.start('200 OK', [('Content-Type', 'text/plain')]) - yield 'some content' + self.start("200 OK", [("Content-Type", "text/plain")]) + yield "some content" def close(self): Namespace.got_close = True @@ -462,9 +500,8 @@ def test_closing_iterator(): def app(environ, start_response): return ClosingIterator(Response(environ, start_response), additional) - app_iter, status, headers = run_wsgi_app( - app, create_environ(), buffered=True) + app_iter, status, headers = run_wsgi_app(app, create_environ(), buffered=True) - assert ''.join(app_iter) == 'some content' + assert "".join(app_iter) == "some content" assert Namespace.got_close assert Namespace.got_additional @@ -19,9 +19,9 @@ deps = commands = coverage run -p -m pytest --tb=short --basetemp={envtmpdir} {posargs} [testenv:stylecheck] -deps = flake8 +deps = pre-commit skip_install = true -commands = flake8 {posargs} +commands = pre-commit run --all-files --show-diff-on-failure [testenv:docs-html] deps = pallets-sphinx-themes diff --git a/werkzeug-import-rewrite.py b/werkzeug-import-rewrite.py index 0d909d36..81af6a55 100644 --- a/werkzeug-import-rewrite.py +++ b/werkzeug-import-rewrite.py @@ -10,118 +10,194 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -import sys +import difflib import os -import re import posixpath -import difflib +import re +import sys -_from_import_re = re.compile(r'(\s*(>>>|\.\.\.)?\s*)from werkzeug import\s+') -_direct_usage = re.compile(r'(?<!`)(werkzeug\.)([a-zA-Z_][a-zA-Z0-9_]+)') +_from_import_re = re.compile(r"(\s*(>>>|\.\.\.)?\s*)from werkzeug import\s+") +_direct_usage = re.compile(r"(?<!`)(werkzeug\.)([a-zA-Z_][a-zA-Z0-9_]+)") # not necessarily in sync with current werkzeug/__init__.py all_by_module = { - 'werkzeug.debug': ['DebuggedApplication'], - 'werkzeug.local': ['Local', 'LocalManager', 'LocalProxy', 'LocalStack', - 'release_local'], - 'werkzeug.serving': ['run_simple'], - 'werkzeug.test': ['Client', 'EnvironBuilder', 'create_environ', - 'run_wsgi_app'], - 'werkzeug.testapp': ['test_app'], - 'werkzeug.exceptions': ['abort', 'Aborter'], - 'werkzeug.urls': ['url_decode', 'url_encode', 'url_quote', - 'url_quote_plus', 'url_unquote', 'url_unquote_plus', - 'url_fix', 'Href', 'iri_to_uri', 'uri_to_iri'], - 'werkzeug.formparser': ['parse_form_data'], - 'werkzeug.utils': ['escape', 'environ_property', 'append_slash_redirect', - 'redirect', 'cached_property', 'import_string', - 'dump_cookie', 'parse_cookie', 'unescape', - 'format_string', 'find_modules', 'header_property', - 'html', 'xhtml', 'HTMLBuilder', 'validate_arguments', - 'ArgumentValidationError', 'bind_arguments', - 'secure_filename'], - 'werkzeug.wsgi': ['get_current_url', 'get_host', 'pop_path_info', - 'peek_path_info', 'SharedDataMiddleware', - 'DispatcherMiddleware', 'ClosingIterator', 'FileWrapper', - 'make_line_iter', 'LimitedStream', 'responder', - 'wrap_file', 'extract_path_info'], - 'werkzeug.datastructures': ['MultiDict', 'CombinedMultiDict', 'Headers', - 'EnvironHeaders', 'ImmutableList', - 'ImmutableDict', 'ImmutableMultiDict', - 'TypeConversionDict', - 'ImmutableTypeConversionDict', 'Accept', - 'MIMEAccept', 'CharsetAccept', - 'LanguageAccept', 'RequestCacheControl', - 'ResponseCacheControl', 'ETags', 'HeaderSet', - 'WWWAuthenticate', 'Authorization', - 'FileMultiDict', 'CallbackDict', 'FileStorage', - 'OrderedMultiDict', 'ImmutableOrderedMultiDict' - ], - 'werkzeug.useragents': ['UserAgent'], - 'werkzeug.http': ['parse_etags', 'parse_date', 'http_date', 'cookie_date', - 'parse_cache_control_header', 'is_resource_modified', - 'parse_accept_header', 'parse_set_header', 'quote_etag', - 'unquote_etag', 'generate_etag', 'dump_header', - 'parse_list_header', 'parse_dict_header', - 'parse_authorization_header', - 'parse_www_authenticate_header', 'remove_entity_headers', - 'is_entity_header', 'remove_hop_by_hop_headers', - 'parse_options_header', 'dump_options_header', - 'is_hop_by_hop_header', 'unquote_header_value', - 'quote_header_value', 'HTTP_STATUS_CODES'], - 'werkzeug.wrappers': ['BaseResponse', 'BaseRequest', 'Request', 'Response', - 'AcceptMixin', 'ETagRequestMixin', - 'ETagResponseMixin', 'ResponseStreamMixin', - 'CommonResponseDescriptorsMixin', 'UserAgentMixin', - 'AuthorizationMixin', 'WWWAuthenticateMixin', - 'CommonRequestDescriptorsMixin'], - 'werkzeug.security': ['generate_password_hash', 'check_password_hash'], + "werkzeug.debug": ["DebuggedApplication"], + "werkzeug.local": [ + "Local", + "LocalManager", + "LocalProxy", + "LocalStack", + "release_local", + ], + "werkzeug.serving": ["run_simple"], + "werkzeug.test": ["Client", "EnvironBuilder", "create_environ", "run_wsgi_app"], + "werkzeug.testapp": ["test_app"], + "werkzeug.exceptions": ["abort", "Aborter"], + "werkzeug.urls": [ + "url_decode", + "url_encode", + "url_quote", + "url_quote_plus", + "url_unquote", + "url_unquote_plus", + "url_fix", + "Href", + "iri_to_uri", + "uri_to_iri", + ], + "werkzeug.formparser": ["parse_form_data"], + "werkzeug.utils": [ + "escape", + "environ_property", + "append_slash_redirect", + "redirect", + "cached_property", + "import_string", + "dump_cookie", + "parse_cookie", + "unescape", + "format_string", + "find_modules", + "header_property", + "html", + "xhtml", + "HTMLBuilder", + "validate_arguments", + "ArgumentValidationError", + "bind_arguments", + "secure_filename", + ], + "werkzeug.wsgi": [ + "get_current_url", + "get_host", + "pop_path_info", + "peek_path_info", + "SharedDataMiddleware", + "DispatcherMiddleware", + "ClosingIterator", + "FileWrapper", + "make_line_iter", + "LimitedStream", + "responder", + "wrap_file", + "extract_path_info", + ], + "werkzeug.datastructures": [ + "MultiDict", + "CombinedMultiDict", + "Headers", + "EnvironHeaders", + "ImmutableList", + "ImmutableDict", + "ImmutableMultiDict", + "TypeConversionDict", + "ImmutableTypeConversionDict", + "Accept", + "MIMEAccept", + "CharsetAccept", + "LanguageAccept", + "RequestCacheControl", + "ResponseCacheControl", + "ETags", + "HeaderSet", + "WWWAuthenticate", + "Authorization", + "FileMultiDict", + "CallbackDict", + "FileStorage", + "OrderedMultiDict", + "ImmutableOrderedMultiDict", + ], + "werkzeug.useragents": ["UserAgent"], + "werkzeug.http": [ + "parse_etags", + "parse_date", + "http_date", + "cookie_date", + "parse_cache_control_header", + "is_resource_modified", + "parse_accept_header", + "parse_set_header", + "quote_etag", + "unquote_etag", + "generate_etag", + "dump_header", + "parse_list_header", + "parse_dict_header", + "parse_authorization_header", + "parse_www_authenticate_header", + "remove_entity_headers", + "is_entity_header", + "remove_hop_by_hop_headers", + "parse_options_header", + "dump_options_header", + "is_hop_by_hop_header", + "unquote_header_value", + "quote_header_value", + "HTTP_STATUS_CODES", + ], + "werkzeug.wrappers": [ + "BaseResponse", + "BaseRequest", + "Request", + "Response", + "AcceptMixin", + "ETagRequestMixin", + "ETagResponseMixin", + "ResponseStreamMixin", + "CommonResponseDescriptorsMixin", + "UserAgentMixin", + "AuthorizationMixin", + "WWWAuthenticateMixin", + "CommonRequestDescriptorsMixin", + ], + "werkzeug.security": ["generate_password_hash", "check_password_hash"], # the undocumented easteregg ;-) - 'werkzeug._internal': ['_easteregg'] + "werkzeug._internal": ["_easteregg"], } by_item = {} -for module, names in all_by_module.iteritems(): +for module, names in all_by_module.items(): for name in names: by_item[name] = module def find_module(item): - return by_item.get(item, 'werkzeug') + return by_item.get(item, "werkzeug") def complete_fromlist(fromlist, lineiter): fromlist = fromlist.strip() if not fromlist: return [] - if fromlist[0] == '(': - if fromlist[-1] == ')': - return fromlist[1:-1].strip().split(',') - fromlist = fromlist[1:].strip().split(',') + if fromlist[0] == "(": + if fromlist[-1] == ")": + return fromlist[1:-1].strip().split(",") + fromlist = fromlist[1:].strip().split(",") for line in lineiter: line = line.strip() abort = False - if line.endswith(')'): + if line.endswith(")"): line = line[:-1] abort = True - fromlist.extend(line.split(',')) + fromlist.extend(line.split(",")) if abort: break return fromlist - elif fromlist[-1] == '\\': - fromlist = fromlist[:-1].strip().split(',') + elif fromlist[-1] == "\\": + fromlist = fromlist[:-1].strip().split(",") for line in lineiter: line = line.strip() abort = True - if line.endswith('\\'): + if line.endswith("\\"): abort = False line = line[:-1] - fromlist.extend(line.split(',')) + fromlist.extend(line.split(",")) if abort: break return fromlist - return fromlist.split(',') + return fromlist.split(",") def rewrite_from_imports(fromlist, indentation, lineiter): @@ -132,43 +208,48 @@ def rewrite_from_imports(fromlist, indentation, lineiter): continue if len(item) == 1: parsed_items.append((item[0], None)) - elif len(item) == 3 and item[1] == 'as': + elif len(item) == 3 and item[1] == "as": parsed_items.append((item[0], item[2])) else: - raise ValueError('invalid syntax for import') + raise ValueError("invalid syntax for import") new_imports = {} for item, alias in parsed_items: new_imports.setdefault(find_module(item), []).append((item, alias)) for module_name, items in sorted(new_imports.items()): - fromlist_items = sorted(['%s%s' % ( - item, - alias is not None and (' as ' + alias) or '' - ) for (item, alias) in items], reverse=True) + fromlist_items = sorted( + [ + "%s%s" % (item, alias is not None and (" as " + alias) or "") + for (item, alias) in items + ], + reverse=True, + ) - prefix = '%sfrom %s import ' % (indentation, module_name) + prefix = "%sfrom %s import " % (indentation, module_name) item_buffer = [] while fromlist_items: item_buffer.append(fromlist_items.pop()) - fromlist = ', '.join(item_buffer) + fromlist = ", ".join(item_buffer) if len(fromlist) + len(prefix) > 79: - yield prefix + ', '.join(item_buffer[:-1]) + ', \\' + yield prefix + ", ".join(item_buffer[:-1]) + ", \\" item_buffer = [item_buffer[-1]] # doctest continuations - indentation = indentation.replace('>', '.') - prefix = indentation + ' ' - yield prefix + ', '.join(item_buffer) + indentation = indentation.replace(">", ".") + prefix = indentation + " " + yield prefix + ", ".join(item_buffer) def inject_imports(lines, imports): pos = 0 for idx, line in enumerate(lines): - if re.match(r'(from|import)\s+werkzeug', line): + if re.match(r"(from|import)\s+werkzeug", line): pos = idx break - lines[pos:pos] = ['from %s import %s' % (mod, ', '.join(sorted(attrs))) - for mod, attrs in sorted(imports.items())] + lines[pos:pos] = [ + "from %s import %s" % (mod, ", ".join(sorted(attrs))) + for mod, attrs in sorted(imports.items()) + ] def rewrite_file(filename): @@ -182,48 +263,48 @@ def rewrite_file(filename): # rewrite from imports match = _from_import_re.search(line) if match is not None: - fromlist = line[match.end():] - new_file.extend(rewrite_from_imports(fromlist, - match.group(1), - lineiter)) + fromlist = line[match.end() :] + new_file.extend(rewrite_from_imports(fromlist, match.group(1), lineiter)) continue def _handle_match(match): # rewrite attribute access to 'werkzeug' attr = match.group(2) mod = find_module(attr) - if mod == 'werkzeug': + if mod == "werkzeug": return match.group(0) deferred_imports.setdefault(mod, []).append(attr) return attr + new_file.append(_direct_usage.sub(_handle_match, line)) if deferred_imports: inject_imports(new_file, deferred_imports) for line in difflib.unified_diff( - old_file, new_file, - posixpath.normpath(posixpath.join('a', filename)), - posixpath.normpath(posixpath.join('b', filename)), - lineterm='' + old_file, + new_file, + posixpath.normpath(posixpath.join("a", filename)), + posixpath.normpath(posixpath.join("b", filename)), + lineterm="", ): print(line) def rewrite_in_folders(folders): for folder in folders: - for dirpath, dirnames, filenames in os.walk(folder): + for dirpath, _dirnames, filenames in os.walk(folder): for filename in filenames: filename = os.path.join(dirpath, filename) - if filename.endswith(('.rst', '.py')): + if filename.endswith((".rst", ".py")): rewrite_file(filename) def main(): if len(sys.argv) == 1: - print('usage: werkzeug-import-rewrite.py [folders]') + print("usage: werkzeug-import-rewrite.py [folders]") sys.exit(1) rewrite_in_folders(sys.argv[1:]) -if __name__ == '__main__': +if __name__ == "__main__": main() |
