summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Lord <davidism@gmail.com>2019-02-13 11:44:18 -0800
committerDavid Lord <davidism@gmail.com>2019-03-08 08:01:31 -0800
commitab6150fa49afc61b0c5eed6d9545d03d1958e384 (patch)
treead5f13c9c2775ca59cc8e82ec124c4e065a65d1b
parent048d707d25685e6aea675c53945ceb7619e60344 (diff)
downloadwerkzeug-code-style.tar.gz
apply code stylecode-style
* reorder-python-imports * line fixers * black * flake8
-rw-r--r--.editorconfig13
-rw-r--r--.gitattributes3
-rw-r--r--.pre-commit-config.yaml28
-rwxr-xr-xbench/wzbench.py316
-rw-r--r--docs/routing.rst1
-rw-r--r--docs/test.rst2
-rw-r--r--docs/unicode.rst2
-rw-r--r--examples/README.rst8
-rw-r--r--examples/contrib/securecookie.py27
-rw-r--r--examples/contrib/sessions.py23
-rw-r--r--examples/cookieauth.py63
-rw-r--r--examples/coolmagic/__init__.py2
-rw-r--r--examples/coolmagic/application.py33
-rw-r--r--examples/coolmagic/helpers.py2
-rw-r--r--examples/coolmagic/utils.py41
-rw-r--r--examples/coolmagic/views/static.py10
-rw-r--r--examples/couchy/README1
-rw-r--r--examples/couchy/application.py30
-rw-r--r--examples/couchy/models.py22
-rw-r--r--examples/couchy/utils.py51
-rw-r--r--examples/couchy/views.py64
-rw-r--r--examples/cupoftee/__init__.py2
-rw-r--r--examples/cupoftee/application.py56
-rw-r--r--examples/cupoftee/db.py10
-rw-r--r--examples/cupoftee/network.py43
-rw-r--r--examples/cupoftee/pages.py40
-rw-r--r--examples/cupoftee/utils.py2
-rw-r--r--examples/httpbasicauth.py23
-rw-r--r--examples/i18nurls/__init__.py2
-rw-r--r--examples/i18nurls/application.py44
-rw-r--r--examples/i18nurls/urls.py27
-rw-r--r--examples/i18nurls/views.py29
-rwxr-xr-xexamples/manage-coolmagic.py40
-rwxr-xr-xexamples/manage-couchy.py50
-rwxr-xr-xexamples/manage-cupoftee.py36
-rwxr-xr-xexamples/manage-i18nurls.py40
-rwxr-xr-xexamples/manage-plnt.py91
-rwxr-xr-xexamples/manage-shorty.py50
-rwxr-xr-xexamples/manage-simplewiki.py51
-rwxr-xr-xexamples/manage-webpylike.py46
-rw-r--r--examples/partial/complex_routing.py60
-rw-r--r--examples/plnt/__init__.py2
-rw-r--r--examples/plnt/database.py67
-rw-r--r--examples/plnt/sync.py46
-rw-r--r--examples/plnt/utils.py56
-rw-r--r--examples/plnt/views.py23
-rw-r--r--examples/plnt/webapp.py25
-rw-r--r--examples/shortly/shortly.py108
-rw-r--r--examples/shorty/application.py24
-rw-r--r--examples/shorty/models.py31
-rw-r--r--examples/shorty/utils.py64
-rw-r--r--examples/shorty/views.py58
-rw-r--r--examples/simplewiki/__init__.py2
-rw-r--r--examples/simplewiki/actions.py190
-rw-r--r--examples/simplewiki/application.py38
-rw-r--r--examples/simplewiki/database.py95
-rw-r--r--examples/simplewiki/specialpages.py35
-rw-r--r--examples/simplewiki/templates/macros.xml2
-rw-r--r--examples/simplewiki/utils.py68
-rw-r--r--examples/upload.py35
-rw-r--r--examples/webpylike/example.py13
-rw-r--r--examples/webpylike/webpylike.py21
-rw-r--r--setup.cfg27
-rw-r--r--setup.py73
-rw-r--r--src/werkzeug/__init__.py243
-rw-r--r--src/werkzeug/_compat.py73
-rw-r--r--src/werkzeug/_internal.py194
-rw-r--r--src/werkzeug/_reloader.py83
-rw-r--r--src/werkzeug/contrib/atom.py225
-rw-r--r--src/werkzeug/contrib/cache.py169
-rw-r--r--src/werkzeug/contrib/fixers.py61
-rw-r--r--src/werkzeug/contrib/iterio.py85
-rw-r--r--src/werkzeug/contrib/securecookie.py102
-rw-r--r--src/werkzeug/contrib/sessions.py136
-rw-r--r--src/werkzeug/contrib/wrappers.py109
-rw-r--r--src/werkzeug/datastructures.py700
-rw-r--r--src/werkzeug/debug/__init__.py265
-rw-r--r--src/werkzeug/debug/console.py61
-rw-r--r--src/werkzeug/debug/repr.py153
-rw-r--r--src/werkzeug/debug/tbtools.py265
-rw-r--r--src/werkzeug/exceptions.py287
-rw-r--r--src/werkzeug/filesystem.py30
-rw-r--r--src/werkzeug/formparser.py287
-rw-r--r--src/werkzeug/http.py513
-rw-r--r--src/werkzeug/local.py55
-rw-r--r--src/werkzeug/middleware/http_proxy.py5
-rw-r--r--src/werkzeug/middleware/lint.py2
-rw-r--r--src/werkzeug/middleware/profiler.py1
-rw-r--r--src/werkzeug/middleware/proxy_fix.py2
-rw-r--r--src/werkzeug/middleware/shared_data.py3
-rw-r--r--src/werkzeug/posixemulation.py39
-rw-r--r--src/werkzeug/routing.py710
-rw-r--r--src/werkzeug/security.py90
-rw-r--r--src/werkzeug/serving.py516
-rw-r--r--src/werkzeug/test.py459
-rw-r--r--src/werkzeug/testapp.py111
-rw-r--r--src/werkzeug/urls.py436
-rw-r--r--src/werkzeug/useragents.py141
-rw-r--r--src/werkzeug/utils.py333
-rw-r--r--src/werkzeug/wrappers/accept.py14
-rw-r--r--src/werkzeug/wrappers/auth.py12
-rw-r--r--src/werkzeug/wrappers/base_request.py232
-rw-r--r--src/werkzeug/wrappers/base_response.py176
-rw-r--r--src/werkzeug/wrappers/common_descriptors.py349
-rw-r--r--src/werkzeug/wrappers/etag.py122
-rw-r--r--src/werkzeug/wrappers/json.py8
-rw-r--r--src/werkzeug/wrappers/request.py11
-rw-r--r--src/werkzeug/wrappers/response.py20
-rw-r--r--src/werkzeug/wrappers/user_agent.py9
-rw-r--r--src/werkzeug/wsgi.py230
-rw-r--r--tests/conftest.py87
-rw-r--r--tests/contrib/cache/conftest.py11
-rw-r--r--tests/contrib/cache/test_cache.py187
-rw-r--r--tests/contrib/test_atom.py50
-rw-r--r--tests/contrib/test_fixers.py139
-rw-r--r--tests/contrib/test_iterio.py120
-rw-r--r--tests/contrib/test_securecookie.py35
-rw-r--r--tests/contrib/test_sessions.py6
-rw-r--r--tests/contrib/test_wrappers.py81
-rw-r--r--tests/hypothesis/test_urls.py10
-rw-r--r--tests/multipart/firefox3-2png1txt/request.http (renamed from tests/multipart/firefox3-2png1txt/request.txt)bin1739 -> 1739 bytes
-rw-r--r--tests/multipart/firefox3-2png1txt/text.txt2
-rw-r--r--tests/multipart/firefox3-2pnglongtext/request.http (renamed from tests/multipart/firefox3-2pnglongtext/request.txt)bin2042 -> 2042 bytes
-rw-r--r--tests/multipart/firefox3-2pnglongtext/text.txt2
-rw-r--r--tests/multipart/ie6-2png1txt/request.http (renamed from tests/multipart/ie6-2png1txt/request.txt)bin1798 -> 1798 bytes
-rw-r--r--tests/multipart/ie6-2png1txt/text.txt2
-rw-r--r--tests/multipart/ie7_full_path_request.http (renamed from tests/multipart/ie7_full_path_request.txt)bin30044 -> 30044 bytes
-rw-r--r--tests/multipart/opera8-2png1txt/request.http (renamed from tests/multipart/opera8-2png1txt/request.txt)bin1740 -> 1740 bytes
-rw-r--r--tests/multipart/opera8-2png1txt/text.txt2
-rw-r--r--tests/multipart/test_collect.py53
-rw-r--r--tests/multipart/webkit3-2png1txt/request.http (renamed from tests/multipart/webkit3-2png1txt/request.txt)bin2408 -> 2408 bytes
-rw-r--r--tests/multipart/webkit3-2png1txt/text.txt2
-rw-r--r--tests/res/chunked.http (renamed from tests/res/chunked.txt)0
-rw-r--r--tests/test_compat.py32
-rw-r--r--tests/test_datastructures.py901
-rw-r--r--tests/test_debug.py222
-rw-r--r--tests/test_exceptions.py80
-rw-r--r--tests/test_formparser.py590
-rw-r--r--tests/test_http.py784
-rw-r--r--tests/test_internal.py43
-rw-r--r--tests/test_local.py30
-rw-r--r--tests/test_routing.py1336
-rw-r--r--tests/test_security.py210
-rw-r--r--tests/test_serving.py608
-rw-r--r--tests/test_test.py574
-rw-r--r--tests/test_urls.py458
-rw-r--r--tests/test_utils.py294
-rw-r--r--tests/test_wrappers.py1334
-rw-r--r--tests/test_wsgi.py523
-rw-r--r--tox.ini4
-rw-r--r--werkzeug-import-rewrite.py285
151 files changed, 11013 insertions, 8798 deletions
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000..e32c8029
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,13 @@
+root = true
+
+[*]
+indent_style = space
+indent_size = 4
+insert_final_newline = true
+trim_trailing_whitespace = true
+end_of_line = lf
+charset = utf-8
+max_line_length = 88
+
+[*.{yml,yaml,json,js,css,html}]
+indent_size = 2
diff --git a/.gitattributes b/.gitattributes
index 96d79dad..5946e823 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1 @@
-tests/res/chunked.txt binary
-
+tests/**/*.http binary
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..2fb46619
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,28 @@
+repos:
+ - repo: https://github.com/asottile/reorder_python_imports
+ rev: v1.4.0
+ hooks:
+ - id: reorder-python-imports
+ name: Reorder Python imports (src, tests)
+ files: "^(?!examples/)"
+ args: ["--application-directories", ".:src"]
+ - id: reorder-python-imports
+ name: Reorder Python imports (examples)
+ files: "^examples/"
+ args: ["--application-directories", "examples"]
+ - repo: https://github.com/ambv/black
+ rev: 18.9b0
+ hooks:
+ - id: black
+ - repo: https://gitlab.com/pycqa/flake8
+ rev: 3.7.7
+ hooks:
+ - id: flake8
+ additional_dependencies: [flake8-bugbear]
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v2.1.0
+ hooks:
+ - id: check-byte-order-marker
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ exclude: "^tests/.*.http$"
diff --git a/bench/wzbench.py b/bench/wzbench.py
index f75742f7..6270b620 100755
--- a/bench/wzbench.py
+++ b/bench/wzbench.py
@@ -11,17 +11,20 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from __future__ import division, print_function
-import os
+from __future__ import division
+from __future__ import print_function
+
import gc
-import sys
+import os
import subprocess
+import sys
+from timeit import default_timer as timer
+from types import FunctionType
+
try:
from cStringIO import StringIO
except ImportError:
from io import StringIO
-from timeit import default_timer as timer
-from types import FunctionType
PY2 = sys.version_info[0] == 2
@@ -30,10 +33,9 @@ if not PY2:
# create a new module where we later store all the werkzeug attributes.
-wz = type(sys)('werkzeug_nonlazy')
-sys.path.insert(0, '<DUMMY>')
-null_out = open(os.devnull, 'w')
-
+wz = type(sys)("werkzeug_nonlazy")
+sys.path.insert(0, "<DUMMY>")
+null_out = open(os.devnull, "w")
# ±4% are ignored
TOLERANCE = 0.04
@@ -47,8 +49,9 @@ def find_hg_tag(path):
"""Returns the current node or tag for the given path."""
tags = {}
try:
- client = subprocess.Popen(['hg', 'cat', '-r', 'tip', '.hgtags'],
- stdout=subprocess.PIPE, cwd=path)
+ client = subprocess.Popen(
+ ["hg", "cat", "-r", "tip", ".hgtags"], stdout=subprocess.PIPE, cwd=path
+ )
for line in client.communicate()[0].splitlines():
line = line.strip()
if not line:
@@ -58,8 +61,9 @@ def find_hg_tag(path):
except OSError:
return
- client = subprocess.Popen(['hg', 'parent', '--template', '#node#'],
- stdout=subprocess.PIPE, cwd=path)
+ client = subprocess.Popen(
+ ["hg", "parent", "--template", "#node#"], stdout=subprocess.PIPE, cwd=path
+ )
tip = client.communicate()[0].strip()
tag = tags.get(tip)
@@ -75,11 +79,12 @@ def load_werkzeug(path):
# get rid of already imported stuff
wz.__dict__.clear()
for key in sys.modules.keys():
- if key.startswith('werkzeug.') or key == 'werkzeug':
+ if key.startswith("werkzeug.") or key == "werkzeug":
sys.modules.pop(key, None)
# import werkzeug again.
import werkzeug
+
for key in werkzeug.__all__:
setattr(wz, key, getattr(werkzeug, key))
@@ -88,18 +93,18 @@ def load_werkzeug(path):
# get the real version from the setup file
try:
- f = open(os.path.join(path, 'setup.py'))
+ f = open(os.path.join(path, "setup.py"))
except IOError:
pass
else:
try:
for line in f:
line = line.strip()
- if line.startswith('version='):
- return line[8:].strip(' \t,')[1:-1], hg_tag
+ if line.startswith("version="):
+ return line[8:].strip(" \t,")[1:-1], hg_tag
finally:
f.close()
- print('Unknown werkzeug version loaded', file=sys.stderr)
+ print("Unknown werkzeug version loaded", file=sys.stderr)
sys.exit(2)
@@ -115,14 +120,14 @@ def format_func(func):
name = func.__name__
else:
name = func
- if name.startswith('time_'):
+ if name.startswith("time_"):
name = name[5:]
- return name.replace('_', ' ').title()
+ return name.replace("_", " ").title()
def bench(func):
"""Times a single function."""
- sys.stdout.write('%44s ' % format_func(func))
+ sys.stdout.write("%44s " % format_func(func))
sys.stdout.flush()
# figure out how many times we have to run the function to
@@ -130,7 +135,7 @@ def bench(func):
for i in xrange(3, 10):
rounds = 1 << i
t = timer()
- for x in xrange(rounds):
+ for _ in xrange(rounds):
func()
if timer() - t >= 0.2:
break
@@ -142,14 +147,14 @@ def bench(func):
gc.disable()
try:
t = timer()
- for x in xrange(rounds):
+ for _ in xrange(rounds):
func()
return (timer() - t) / rounds * 1000
finally:
gc.enable()
delta = median(_run() for x in xrange(TEST_RUNS))
- sys.stdout.write('%.4f\n' % delta)
+ sys.stdout.write("%.4f\n" % delta)
sys.stdout.flush()
return delta
@@ -158,17 +163,33 @@ def bench(func):
def main():
"""The main entrypoint."""
from optparse import OptionParser
- parser = OptionParser(usage='%prog [options]')
- parser.add_option('--werkzeug-path', '-p', dest='path', default='..',
- help='the path to the werkzeug package. defaults to cwd')
- parser.add_option('--compare', '-c', dest='compare', nargs=2,
- default=False, help='compare two hg nodes of Werkzeug')
- parser.add_option('--init-compare', dest='init_compare',
- action='store_true', default=False,
- help='Initializes the comparison feature')
+
+ parser = OptionParser(usage="%prog [options]")
+ parser.add_option(
+ "--werkzeug-path",
+ "-p",
+ dest="path",
+ default="..",
+ help="the path to the werkzeug package. defaults to cwd",
+ )
+ parser.add_option(
+ "--compare",
+ "-c",
+ dest="compare",
+ nargs=2,
+ default=False,
+ help="compare two hg nodes of Werkzeug",
+ )
+ parser.add_option(
+ "--init-compare",
+ dest="init_compare",
+ action="store_true",
+ default=False,
+ help="Initializes the comparison feature",
+ )
options, args = parser.parse_args()
if args:
- parser.error('Script takes no arguments')
+ parser.error("Script takes no arguments")
if options.compare:
compare(*options.compare)
elif options.init_compare:
@@ -179,65 +200,70 @@ def main():
def init_compare():
"""Initializes the comparison feature."""
- print('Initializing comparison feature')
- subprocess.Popen(['hg', 'clone', '..', 'a']).wait()
- subprocess.Popen(['hg', 'clone', '..', 'b']).wait()
+ print("Initializing comparison feature")
+ subprocess.Popen(["hg", "clone", "..", "a"]).wait()
+ subprocess.Popen(["hg", "clone", "..", "b"]).wait()
def compare(node1, node2):
"""Compares two Werkzeug hg versions."""
- if not os.path.isdir('a'):
- print('error: comparison feature not initialized', file=sys.stderr)
+ if not os.path.isdir("a"):
+ print("error: comparison feature not initialized", file=sys.stderr)
sys.exit(4)
- print('=' * 80)
- print('WERKZEUG INTERNAL BENCHMARK -- COMPARE MODE'.center(80))
- print('-' * 80)
-
- def _error(msg):
- print('error:', msg, file=sys.stderr)
- sys.exit(1)
+ print("=" * 80)
+ print("WERKZEUG INTERNAL BENCHMARK -- COMPARE MODE".center(80))
+ print("-" * 80)
def _hg_update(repo, node):
- hg = lambda *x: subprocess.call(['hg'] + list(x), cwd=repo,
- stdout=null_out, stderr=null_out)
- hg('revert', '-a', '--no-backup')
- client = subprocess.Popen(['hg', 'status', '--unknown', '-n', '-0'],
- stdout=subprocess.PIPE, cwd=repo)
+ def hg(*x):
+ return subprocess.call(
+ ["hg"] + list(x), cwd=repo, stdout=null_out, stderr=null_out
+ )
+
+ hg("revert", "-a", "--no-backup")
+ client = subprocess.Popen(
+ ["hg", "status", "--unknown", "-n", "-0"], stdout=subprocess.PIPE, cwd=repo
+ )
unknown = client.communicate()[0]
if unknown:
- client = subprocess.Popen(['xargs', '-0', 'rm', '-f'], cwd=repo,
- stdout=null_out, stdin=subprocess.PIPE)
+ client = subprocess.Popen(
+ ["xargs", "-0", "rm", "-f"],
+ cwd=repo,
+ stdout=null_out,
+ stdin=subprocess.PIPE,
+ )
client.communicate(unknown)
- hg('pull', '../..')
- hg('update', node)
- if node == 'tip':
- diff = subprocess.Popen(['hg', 'diff'], cwd='..',
- stdout=subprocess.PIPE).communicate()[0]
+ hg("pull", "../..")
+ hg("update", node)
+ if node == "tip":
+ diff = subprocess.Popen(
+ ["hg", "diff"], cwd="..", stdout=subprocess.PIPE
+ ).communicate()[0]
if diff:
- client = subprocess.Popen(['hg', 'import', '--no-commit', '-'],
- cwd=repo, stdout=null_out,
- stdin=subprocess.PIPE)
+ client = subprocess.Popen(
+ ["hg", "import", "--no-commit", "-"],
+ cwd=repo,
+ stdout=null_out,
+ stdin=subprocess.PIPE,
+ )
client.communicate(diff)
- _hg_update('a', node1)
- _hg_update('b', node2)
- d1 = run('a', no_header=True)
- d2 = run('b', no_header=True)
+ _hg_update("a", node1)
+ _hg_update("b", node2)
+ d1 = run("a", no_header=True)
+ d2 = run("b", no_header=True)
- print('DIRECT COMPARISON'.center(80))
- print('-' * 80)
+ print("DIRECT COMPARISON".center(80))
+ print("-" * 80)
for key in sorted(d1):
delta = d1[key] - d2[key]
- if abs(1 - d1[key] / d2[key]) < TOLERANCE or \
- abs(delta) < MIN_RESOLUTION:
- delta = '=='
+ if abs(1 - d1[key] / d2[key]) < TOLERANCE or abs(delta) < MIN_RESOLUTION:
+ delta = "=="
else:
- delta = '%+.4f (%+d%%)' % \
- (delta, round(d2[key] / d1[key] * 100 - 100))
- print('%36s %.4f %.4f %s' %
- (format_func(key), d1[key], d2[key], delta))
- print('-' * 80)
+ delta = "%+.4f (%+d%%)" % (delta, round(d2[key] / d1[key] * 100 - 100))
+ print("%36s %.4f %.4f %s" % (format_func(key), d1[key], d2[key], delta))
+ print("-" * 80)
def run(path, no_header=False):
@@ -245,45 +271,47 @@ def run(path, no_header=False):
wz_version, hg_tag = load_werkzeug(path)
result = {}
if not no_header:
- print('=' * 80)
- print('WERKZEUG INTERNAL BENCHMARK'.center(80))
- print('-' * 80)
- print('Path: %s' % path)
- print('Version: %s' % wz_version)
+ print("=" * 80)
+ print("WERKZEUG INTERNAL BENCHMARK".center(80))
+ print("-" * 80)
+ print("Path: %s" % path)
+ print("Version: %s" % wz_version)
if hg_tag is not None:
- print('HG Tag: %s' % hg_tag)
- print('-' * 80)
+ print("HG Tag: %s" % hg_tag)
+ print("-" * 80)
for key, value in sorted(globals().items()):
- if key.startswith('time_'):
- before = globals().get('before_' + key[5:])
+ if key.startswith("time_"):
+ before = globals().get("before_" + key[5:])
if before:
before()
result[key] = bench(value)
- after = globals().get('after_' + key[5:])
+ after = globals().get("after_" + key[5:])
if after:
after()
- print('-' * 80)
+ print("-" * 80)
return result
URL_DECODED_DATA = dict((str(x), str(x)) for x in xrange(100))
-URL_ENCODED_DATA = '&'.join('%s=%s' % x for x in URL_DECODED_DATA.items())
-MULTIPART_ENCODED_DATA = '\n'.join((
- '--foo',
- 'Content-Disposition: form-data; name=foo',
- '',
- 'this is just bar',
- '--foo',
- 'Content-Disposition: form-data; name=bar',
- '',
- 'blafasel',
- '--foo',
- 'Content-Disposition: form-data; name=foo; filename=wzbench.py',
- 'Content-Type: text/plain',
- '',
- open(__file__.rstrip('c')).read(),
- '--foo--'
-))
+URL_ENCODED_DATA = "&".join("%s=%s" % x for x in URL_DECODED_DATA.items())
+MULTIPART_ENCODED_DATA = "\n".join(
+ (
+ "--foo",
+ "Content-Disposition: form-data; name=foo",
+ "",
+ "this is just bar",
+ "--foo",
+ "Content-Disposition: form-data; name=bar",
+ "",
+ "blafasel",
+ "--foo",
+ "Content-Disposition: form-data; name=foo; filename=wzbench.py",
+ "Content-Type: text/plain",
+ "",
+ open(__file__.rstrip("c")).read(),
+ "--foo--",
+ )
+)
MULTIDICT = None
REQUEST = None
TEST_ENV = None
@@ -304,10 +332,10 @@ def time_parse_form_data_multipart():
# from_values which is known to be slowish in 0.5.1 and higher.
# we don't want to bench two things at once.
environ = {
- 'REQUEST_METHOD': 'POST',
- 'CONTENT_TYPE': 'multipart/form-data; boundary=foo',
- 'wsgi.input': StringIO(MULTIPART_ENCODED_DATA),
- 'CONTENT_LENGTH': str(len(MULTIPART_ENCODED_DATA))
+ "REQUEST_METHOD": "POST",
+ "CONTENT_TYPE": "multipart/form-data; boundary=foo",
+ "wsgi.input": StringIO(MULTIPART_ENCODED_DATA),
+ "CONTENT_LENGTH": str(len(MULTIPART_ENCODED_DATA)),
}
request = wz.Request(environ)
request.form
@@ -315,11 +343,11 @@ def time_parse_form_data_multipart():
def before_multidict_lookup_hit():
global MULTIDICT
- MULTIDICT = wz.MultiDict({'foo': 'bar'})
+ MULTIDICT = wz.MultiDict({"foo": "bar"})
def time_multidict_lookup_hit():
- MULTIDICT['foo']
+ MULTIDICT["foo"]
def after_multidict_lookup_hit():
@@ -334,7 +362,7 @@ def before_multidict_lookup_miss():
def time_multidict_lookup_miss():
try:
- MULTIDICT['foo']
+ MULTIDICT["foo"]
except KeyError:
pass
@@ -351,31 +379,33 @@ def time_cached_property():
return 42
f = Foo()
- for x in xrange(60):
+ for _ in xrange(60):
f.x
def before_request_form_access():
global REQUEST
- data = 'foo=bar&blah=blub'
- REQUEST = wz.Request({
- 'CONTENT_LENGTH': str(len(data)),
- 'wsgi.input': StringIO(data),
- 'REQUEST_METHOD': 'POST',
- 'wsgi.version': (1, 0),
- 'QUERY_STRING': data,
- 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
- 'PATH_INFO': '/',
- 'SCRIPT_NAME': ''
- })
+ data = "foo=bar&blah=blub"
+ REQUEST = wz.Request(
+ {
+ "CONTENT_LENGTH": str(len(data)),
+ "wsgi.input": StringIO(data),
+ "REQUEST_METHOD": "POST",
+ "wsgi.version": (1, 0),
+ "QUERY_STRING": data,
+ "CONTENT_TYPE": "application/x-www-form-urlencoded",
+ "PATH_INFO": "/",
+ "SCRIPT_NAME": "",
+ }
+ )
def time_request_form_access():
- for x in xrange(30):
+ for _ in xrange(30):
REQUEST.path
REQUEST.script_root
- REQUEST.args['foo']
- REQUEST.form['foo']
+ REQUEST.args["foo"]
+ REQUEST.form["foo"]
def after_request_form_access():
@@ -384,12 +414,14 @@ def after_request_form_access():
def time_request_from_values():
- wz.Request.from_values(base_url='http://www.google.com/',
- query_string='foo=bar&blah=blaz',
- input_stream=StringIO(MULTIPART_ENCODED_DATA),
- content_length=len(MULTIPART_ENCODED_DATA),
- content_type='multipart/form-data; '
- 'boundary=foo', method='POST')
+ wz.Request.from_values(
+ base_url="http://www.google.com/",
+ query_string="foo=bar&blah=blaz",
+ input_stream=StringIO(MULTIPART_ENCODED_DATA),
+ content_length=len(MULTIPART_ENCODED_DATA),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
def before_request_shallow_init():
@@ -407,16 +439,14 @@ def after_request_shallow_init():
def time_response_iter_performance():
- resp = wz.Response(u'Hällo Wörld ' * 1000,
- mimetype='text/html')
- for item in resp({'REQUEST_METHOD': 'GET'}, lambda *s: None):
+ resp = wz.Response(u"Hällo Wörld " * 1000, mimetype="text/html")
+ for _ in resp({"REQUEST_METHOD": "GET"}, lambda *s: None):
pass
def time_response_iter_head_performance():
- resp = wz.Response(u'Hällo Wörld ' * 1000,
- mimetype='text/html')
- for item in resp({'REQUEST_METHOD': 'HEAD'}, lambda *s: None):
+ resp = wz.Response(u"Hällo Wörld " * 1000, mimetype="text/html")
+ for _ in resp({"REQUEST_METHOD": "HEAD"}, lambda *s: None):
pass
@@ -427,9 +457,9 @@ def before_local_manager_dispatch():
def time_local_manager_dispatch():
- for x in xrange(10):
+ for _ in xrange(10):
LOCAL.x = 42
- for x in xrange(10):
+ for _ in xrange(10):
LOCAL.x
@@ -440,14 +470,14 @@ def after_local_manager_dispatch():
def before_html_builder():
global TABLE
- TABLE = [['col 1', 'col 2', 'col 3', '4', '5', '6'] for x in range(10)]
+ TABLE = [["col 1", "col 2", "col 3", "4", "5", "6"] for x in range(10)]
def time_html_builder():
html_rows = []
for row in TABLE: # noqa
- html_cols = [wz.html.td(col, class_='col') for col in row]
- html_rows.append(wz.html.tr(class_='row', *html_cols))
+ html_cols = [wz.html.td(col, class_="col") for col in row]
+ html_rows.append(wz.html.tr(class_="row", *html_cols))
wz.html.table(*html_rows)
@@ -456,9 +486,9 @@ def after_html_builder():
TABLE = None
-if __name__ == '__main__':
+if __name__ == "__main__":
os.chdir(os.path.dirname(__file__) or os.path.curdir)
try:
main()
except KeyboardInterrupt:
- print('\nInterrupted!', file=sys.stderr)
+ print("\nInterrupted!", file=sys.stderr)
diff --git a/docs/routing.rst b/docs/routing.rst
index 9db3e177..f5a768cc 100644
--- a/docs/routing.rst
+++ b/docs/routing.rst
@@ -221,4 +221,3 @@ Variable parts are of course also possible in the host section::
Rule('/', endpoint='www_index', host='www.example.com'),
Rule('/', endpoint='user_index', host='<user>.example.com')
], host_matching=True)
-
diff --git a/docs/test.rst b/docs/test.rst
index eb0130b3..c7e213f8 100644
--- a/docs/test.rst
+++ b/docs/test.rst
@@ -139,7 +139,7 @@ Testing API
A dict with values that are used to override the generated environ.
.. attribute:: input_stream
-
+
The optional input stream. This and :attr:`form` / :attr:`files`
is mutually exclusive. Also do not provide this stream if the
request method is not `POST` / `PUT` or something comparable.
diff --git a/docs/unicode.rst b/docs/unicode.rst
index 0dc977a7..d8eaa84c 100644
--- a/docs/unicode.rst
+++ b/docs/unicode.rst
@@ -95,7 +95,7 @@ argument that behaves like the `errors` parameter of the builtin string method
Unlike the regular python decoding Werkzeug does not raise an
:exc:`UnicodeDecodeError` if the decoding failed but an
:exc:`~exceptions.HTTPUnicodeError` which
-is a direct subclass of `UnicodeError` and the `BadRequest` HTTP exception.
+is a direct subclass of `UnicodeError` and the `BadRequest` HTTP exception.
The reason is that if this exception is not caught by the application but
a catch-all for HTTP exceptions exists a default `400 BAD REQUEST` error
page is displayed.
diff --git a/examples/README.rst b/examples/README.rst
index 593394a1..2b9df866 100644
--- a/examples/README.rst
+++ b/examples/README.rst
@@ -23,7 +23,7 @@ find in real life :-)
A simple Wiki implementation.
Requirements:
-
+
- SQLAlchemy
- Creoleparser >= 0.7
- genshi
@@ -50,7 +50,7 @@ find in real life :-)
A planet called plnt, pronounce plant.
Requirements:
-
+
- SQLAlchemy
- Jinja2
- feedparser
@@ -76,7 +76,7 @@ find in real life :-)
A tinyurl clone for the Werkzeug tutorial.
Requirements:
-
+
- SQLAlchemy
- Jinja2
@@ -98,7 +98,7 @@ find in real life :-)
Like shorty, but implemented using CouchDB.
Requirements :
-
+
- werkzeug : http://werkzeug.pocoo.org
- jinja : http://jinja.pocoo.org
- couchdb 0.72 & above : https://couchdb.apache.org/
diff --git a/examples/contrib/securecookie.py b/examples/contrib/securecookie.py
index a0a566d1..2f6544d3 100644
--- a/examples/contrib/securecookie.py
+++ b/examples/contrib/securecookie.py
@@ -9,15 +9,16 @@
:license: BSD-3-Clause
"""
from time import asctime
-from werkzeug.serving import run_simple
-from werkzeug.wrappers import BaseRequest, BaseResponse
+
from werkzeug.contrib.securecookie import SecureCookie
+from werkzeug.serving import run_simple
+from werkzeug.wrappers import BaseRequest
+from werkzeug.wrappers import BaseResponse
-SECRET_KEY = 'V\x8a$m\xda\xe9\xc3\x0f|f\x88\xbccj>\x8bI^3+'
+SECRET_KEY = "V\x8a$m\xda\xe9\xc3\x0f|f\x88\xbccj>\x8bI^3+"
class Request(BaseRequest):
-
def __init__(self, environ):
BaseRequest.__init__(self, environ)
self.session = SecureCookie.load_cookie(self, secret_key=SECRET_KEY)
@@ -28,23 +29,23 @@ def index(request):
def get_time(request):
- return 'Time: %s' % request.session.get('time', 'not set')
+ return "Time: %s" % request.session.get("time", "not set")
def set_time(request):
- request.session['time'] = time = asctime()
- return 'Time set to %s' % time
+ request.session["time"] = time = asctime()
+ return "Time set to %s" % time
def application(environ, start_response):
request = Request(environ)
- response = BaseResponse({
- 'get': get_time,
- 'set': set_time
- }.get(request.path.strip('/'), index)(request), mimetype='text/html')
+ response = BaseResponse(
+ {"get": get_time, "set": set_time}.get(request.path.strip("/"), index)(request),
+ mimetype="text/html",
+ )
request.session.save_cookie(response)
return response(environ, start_response)
-if __name__ == '__main__':
- run_simple('localhost', 5000, application)
+if __name__ == "__main__":
+ run_simple("localhost", 5000, application)
diff --git a/examples/contrib/sessions.py b/examples/contrib/sessions.py
index ecf464a2..c5eef576 100644
--- a/examples/contrib/sessions.py
+++ b/examples/contrib/sessions.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+from werkzeug.contrib.sessions import SessionMiddleware
+from werkzeug.contrib.sessions import SessionStore
from werkzeug.serving import run_simple
-from werkzeug.contrib.sessions import SessionStore, SessionMiddleware
class MemorySessionStore(SessionStore):
-
def __init__(self, session_class=None):
SessionStore.__init__(self, session_class=None)
self.sessions = {}
@@ -23,21 +23,22 @@ class MemorySessionStore(SessionStore):
def application(environ, start_response):
- session = environ['werkzeug.session']
- session['visit_count'] = session.get('visit_count', 0) + 1
+ session = environ["werkzeug.session"]
+ session["visit_count"] = session.get("visit_count", 0) + 1
- start_response('200 OK', [('Content-Type', 'text/html')])
- return ['''
- <!doctype html>
+ start_response("200 OK", [("Content-Type", "text/html")])
+ return [
+ """<!doctype html>
<title>Session Example</title>
<h1>Session Example</h1>
- <p>You visited this page %d times.</p>
- ''' % session['visit_count']]
+ <p>You visited this page %d times.</p>"""
+ % session["visit_count"]
+ ]
def make_app():
return SessionMiddleware(application, MemorySessionStore())
-if __name__ == '__main__':
- run_simple('localhost', 5000, make_app())
+if __name__ == "__main__":
+ run_simple("localhost", 5000, make_app())
diff --git a/examples/cookieauth.py b/examples/cookieauth.py
index 64b5ae0d..ba23bda4 100644
--- a/examples/cookieauth.py
+++ b/examples/cookieauth.py
@@ -10,25 +10,25 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from werkzeug.serving import run_simple
-from werkzeug.utils import cached_property, escape, redirect
-from werkzeug.wrappers import Request, Response
from werkzeug.contrib.securecookie import SecureCookie
+from werkzeug.serving import run_simple
+from werkzeug.utils import cached_property
+from werkzeug.utils import escape
+from werkzeug.utils import redirect
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
# don't use this key but a different one; you could just use
# os.unrandom(20) to get something random. Changing this key
# invalidates all sessions at once.
-SECRET_KEY = '\xfa\xdd\xb8z\xae\xe0}4\x8b\xea'
+SECRET_KEY = "\xfa\xdd\xb8z\xae\xe0}4\x8b\xea"
# the cookie name for the session
-COOKIE_NAME = 'session'
+COOKIE_NAME = "session"
# the users that may access
-USERS = {
- 'admin': 'default',
- 'user1': 'default'
-}
+USERS = {"admin": "default", "user1": "default"}
class AppRequest(Request):
@@ -36,11 +36,11 @@ class AppRequest(Request):
def logout(self):
"""Log the user out."""
- self.session.pop('username', None)
+ self.session.pop("username", None)
def login(self, username):
"""Log the user in."""
- self.session['username'] = username
+ self.session["username"] = username
@property
def logged_in(self):
@@ -50,7 +50,7 @@ class AppRequest(Request):
@property
def user(self):
"""The user that is logged in."""
- return self.session.get('username')
+ return self.session.get("username")
@cached_property
def session(self):
@@ -61,16 +61,16 @@ class AppRequest(Request):
def login_form(request):
- error = ''
- if request.method == 'POST':
- username = request.form.get('username')
- password = request.form.get('password')
+ error = ""
+ if request.method == "POST":
+ username = request.form.get("username")
+ password = request.form.get("password")
if password and USERS.get(username) == password:
request.login(username)
- return redirect('')
- error = '<p>Invalid credentials'
- return Response('''
- <title>Login</title><h1>Login</h1>
+ return redirect("")
+ error = "<p>Invalid credentials"
+ return Response(
+ """<title>Login</title><h1>Login</h1>
<p>Not logged in.
%s
<form action="" method="post">
@@ -79,23 +79,28 @@ def login_form(request):
<input type="text" name="username" size=20>
<input type="password" name="password", size=20>
<input type="submit" value="Login">
- </form>''' % error, mimetype='text/html')
+ </form>"""
+ % error,
+ mimetype="text/html",
+ )
def index(request):
- return Response('''
- <title>Logged in</title>
+ return Response(
+ """<title>Logged in</title>
<h1>Logged in</h1>
<p>Logged in as %s
- <p><a href="/?do=logout">Logout</a>
- ''' % escape(request.user), mimetype='text/html')
+ <p><a href="/?do=logout">Logout</a>"""
+ % escape(request.user),
+ mimetype="text/html",
+ )
@AppRequest.application
def application(request):
- if request.args.get('do') == 'logout':
+ if request.args.get("do") == "logout":
request.logout()
- response = redirect('.')
+ response = redirect(".")
elif request.logged_in:
response = index(request)
else:
@@ -104,5 +109,5 @@ def application(request):
return response
-if __name__ == '__main__':
- run_simple('localhost', 4000, application)
+if __name__ == "__main__":
+ run_simple("localhost", 4000, application)
diff --git a/examples/coolmagic/__init__.py b/examples/coolmagic/__init__.py
index 88d5db80..0526f1d6 100644
--- a/examples/coolmagic/__init__.py
+++ b/examples/coolmagic/__init__.py
@@ -8,4 +8,4 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from coolmagic.application import make_app
+from .application import make_app
diff --git a/examples/coolmagic/application.py b/examples/coolmagic/application.py
index 10ffd2a2..730f4d83 100644
--- a/examples/coolmagic/application.py
+++ b/examples/coolmagic/application.py
@@ -12,11 +12,18 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from os import path, listdir
-from coolmagic.utils import Request, local_manager
+from os import listdir
+from os import path
+
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.middleware.shared_data import SharedDataMiddleware
-from werkzeug.routing import Map, Rule, RequestRedirect
-from werkzeug.exceptions import HTTPException, NotFound
+from werkzeug.routing import Map
+from werkzeug.routing import RequestRedirect
+from werkzeug.routing import Rule
+
+from .utils import local_manager
+from .utils import Request
class CoolMagicApplication(object):
@@ -27,17 +34,17 @@ class CoolMagicApplication(object):
def __init__(self, config):
self.config = config
- for fn in listdir(path.join(path.dirname(__file__), 'views')):
- if fn.endswith('.py') and fn != '__init__.py':
- __import__('coolmagic.views.' + fn[:-3])
+ for fn in listdir(path.join(path.dirname(__file__), "views")):
+ if fn.endswith(".py") and fn != "__init__.py":
+ __import__("coolmagic.views." + fn[:-3])
from coolmagic.utils import exported_views
+
rules = [
# url for shared data. this will always be unmatched
# because either the middleware or the webserver
# handles that request first.
- Rule('/public/<path:file>',
- endpoint='shared_data')
+ Rule("/public/<path:file>", endpoint="shared_data")
]
self.views = {}
for endpoint, (func, rule, extra) in exported_views.items():
@@ -53,7 +60,7 @@ class CoolMagicApplication(object):
endpoint, args = urls.match(req.path)
resp = self.views[endpoint](**args)
except NotFound:
- resp = self.views['static.not_found']()
+ resp = self.views["static.not_found"]()
except (HTTPException, RequestRedirect) as e:
resp = e
return resp(environ, start_response)
@@ -68,9 +75,9 @@ def make_app(config=None):
app = CoolMagicApplication(config)
# static stuff
- app = SharedDataMiddleware(app, {
- '/public': path.join(path.dirname(__file__), 'public')
- })
+ app = SharedDataMiddleware(
+ app, {"/public": path.join(path.dirname(__file__), "public")}
+ )
# clean up locals
app = local_manager.make_middleware(app)
diff --git a/examples/coolmagic/helpers.py b/examples/coolmagic/helpers.py
index 54638335..4cd4ac45 100644
--- a/examples/coolmagic/helpers.py
+++ b/examples/coolmagic/helpers.py
@@ -8,7 +8,7 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from coolmagic.utils import ThreadedRequest
+from .utils import ThreadedRequest
#: a thread local proxy request object
diff --git a/examples/coolmagic/utils.py b/examples/coolmagic/utils.py
index b69f95a4..f4cf20d5 100644
--- a/examples/coolmagic/utils.py
+++ b/examples/coolmagic/utils.py
@@ -11,17 +11,21 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from os.path import dirname, join
-from jinja2 import Environment, FileSystemLoader
-from werkzeug.local import Local, LocalManager
-from werkzeug.wrappers import BaseRequest, BaseResponse
+from os.path import dirname
+from os.path import join
+
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+from werkzeug.local import Local
+from werkzeug.local import LocalManager
+from werkzeug.wrappers import BaseRequest
+from werkzeug.wrappers import BaseResponse
local = Local()
local_manager = LocalManager([local])
template_env = Environment(
- loader=FileSystemLoader(join(dirname(__file__), 'templates'),
- use_memcache=False)
+ loader=FileSystemLoader(join(dirname(__file__), "templates"), use_memcache=False)
)
exported_views = {}
@@ -31,19 +35,23 @@ def export(string, template=None, **extra):
Decorator for registering view functions and adding
templates to it.
"""
+
def wrapped(f):
- endpoint = (f.__module__ + '.' + f.__name__)[16:]
+ endpoint = (f.__module__ + "." + f.__name__)[16:]
if template is not None:
old_f = f
+
def f(**kwargs):
rv = old_f(**kwargs)
if not isinstance(rv, Response):
rv = TemplateResponse(template, **(rv or {}))
return rv
+
f.__name__ = old_f.__name__
f.__doc__ = old_f.__doc__
exported_views[endpoint] = (f, string, extra)
return f
+
return wrapped
@@ -59,7 +67,8 @@ class Request(BaseRequest):
The concrete request object used in the WSGI application.
It has some helper functions that can be used to build URLs.
"""
- charset = 'utf-8'
+
+ charset = "utf-8"
def __init__(self, environ, url_adapter):
BaseRequest.__init__(self, environ)
@@ -74,9 +83,8 @@ class ThreadedRequest(object):
"""
def __getattr__(self, name):
- if name == '__members__':
- return [x for x in dir(local.request) if not
- x.startswith('_')]
+ if name == "__members__":
+ return [x for x in dir(local.request) if not x.startswith("_")]
return getattr(local.request, name)
def __setattr__(self, name, value):
@@ -87,8 +95,9 @@ class Response(BaseResponse):
"""
The concrete response object for the WSGI application.
"""
- charset = 'utf-8'
- default_mimetype = 'text/html'
+
+ charset = "utf-8"
+ default_mimetype = "text/html"
class TemplateResponse(Response):
@@ -98,9 +107,7 @@ class TemplateResponse(Response):
def __init__(self, template_name, **values):
from coolmagic import helpers
- values.update(
- request=local.request,
- h=helpers
- )
+
+ values.update(request=local.request, h=helpers)
template = template_env.get_template(template_name)
Response.__init__(self, template.render(values))
diff --git a/examples/coolmagic/views/static.py b/examples/coolmagic/views/static.py
index edb09372..f4f95409 100644
--- a/examples/coolmagic/views/static.py
+++ b/examples/coolmagic/views/static.py
@@ -11,22 +11,22 @@
from coolmagic.utils import export
-@export('/', template='static/index.html')
+@export("/", template="static/index.html")
def index():
pass
-@export('/about', template='static/about.html')
+@export("/about", template="static/about.html")
def about():
pass
-@export('/broken')
+@export("/broken")
def broken():
- raise RuntimeError('that\'s really broken')
+ raise RuntimeError("that's really broken")
-@export(None, template='static/not_found.html')
+@export(None, template="static/not_found.html")
def not_found():
"""
This function is always executed if an url does not
diff --git a/examples/couchy/README b/examples/couchy/README
index 1821ed11..24960448 100644
--- a/examples/couchy/README
+++ b/examples/couchy/README
@@ -5,4 +5,3 @@ Requirements :
- jinja : http://jinja.pocoo.org
- couchdb 0.72 & above : https://couchdb.apache.org/
- couchdb-python 0.3 & above : https://github.com/djc/couchdb-python
-
diff --git a/examples/couchy/application.py b/examples/couchy/application.py
index f3a56a63..b958dcbc 100644
--- a/examples/couchy/application.py
+++ b/examples/couchy/application.py
@@ -1,27 +1,28 @@
from couchdb.client import Server
-from couchy.utils import STATIC_PATH, local, local_manager, \
- url_map
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.middleware.shared_data import SharedDataMiddleware
from werkzeug.wrappers import Request
from werkzeug.wsgi import ClosingIterator
-from werkzeug.exceptions import HTTPException, NotFound
-from couchy import views
-from couchy.models import URL
+from . import views
+from .models import URL
+from .utils import local
+from .utils import local_manager
+from .utils import STATIC_PATH
+from .utils import url_map
-class Couchy(object):
+class Couchy(object):
def __init__(self, db_uri):
local.application = self
server = Server(db_uri)
try:
- db = server.create('urls')
- except:
- db = server['urls']
- self.dispatch = SharedDataMiddleware(self.dispatch, {
- '/static': STATIC_PATH
- })
+ db = server.create("urls")
+ except Exception:
+ db = server["urls"]
+ self.dispatch = SharedDataMiddleware(self.dispatch, {"/static": STATIC_PATH})
URL.db = db
@@ -38,8 +39,9 @@ class Couchy(object):
response.status_code = 404
except HTTPException as e:
response = e
- return ClosingIterator(response(environ, start_response),
- [local_manager.cleanup])
+ return ClosingIterator(
+ response(environ, start_response), [local_manager.cleanup]
+ )
def __call__(self, environ, start_response):
return self.dispatch(environ, start_response)
diff --git a/examples/couchy/models.py b/examples/couchy/models.py
index 4621a744..a0b50ca1 100644
--- a/examples/couchy/models.py
+++ b/examples/couchy/models.py
@@ -1,6 +1,12 @@
from datetime import datetime
-from couchdb.mapping import Document, TextField, BooleanField, DateTimeField
-from couchy.utils import url_for, get_random_uid
+
+from couchdb.mapping import BooleanField
+from couchdb.mapping import DateTimeField
+from couchdb.mapping import Document
+from couchdb.mapping import TextField
+
+from .utils import get_random_uid
+from .utils import url_for
class URL(Document):
@@ -19,13 +25,15 @@ class URL(Document):
return URL.db.query(code)
def store(self):
- if getattr(self._data, 'id', None) is None:
+ if getattr(self._data, "id", None) is None:
new_id = self.shorty_id if self.shorty_id else None
while 1:
id = new_id if new_id else get_random_uid()
try:
- docid = URL.db.resource.put(content=self._data, path='/%s/' % str(id))['id']
- except:
+ docid = URL.db.resource.put(
+ content=self._data, path="/%s/" % str(id)
+ )["id"]
+ except Exception:
continue
if docid:
break
@@ -36,7 +44,7 @@ class URL(Document):
@property
def short_url(self):
- return url_for('link', uid=self.id, _external=True)
+ return url_for("link", uid=self.id, _external=True)
def __repr__(self):
- return '<URL %r>' % self.id
+ return "<URL %r>" % self.id
diff --git a/examples/couchy/utils.py b/examples/couchy/utils.py
index 4fe666a3..571a7ed9 100644
--- a/examples/couchy/utils.py
+++ b/examples/couchy/utils.py
@@ -1,49 +1,62 @@
from os import path
-from random import sample, randrange
-from jinja2 import Environment, FileSystemLoader
-from werkzeug.local import Local, LocalManager
+from random import randrange
+from random import sample
+
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+from werkzeug.local import Local
+from werkzeug.local import LocalManager
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
from werkzeug.urls import url_parse
from werkzeug.utils import cached_property
from werkzeug.wrappers import Response
-from werkzeug.routing import Map, Rule
-TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates')
-STATIC_PATH = path.join(path.dirname(__file__), 'static')
-ALLOWED_SCHEMES = frozenset(['http', 'https', 'ftp', 'ftps'])
-URL_CHARS = 'abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789'
+TEMPLATE_PATH = path.join(path.dirname(__file__), "templates")
+STATIC_PATH = path.join(path.dirname(__file__), "static")
+ALLOWED_SCHEMES = frozenset(["http", "https", "ftp", "ftps"])
+URL_CHARS = "abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789"
local = Local()
local_manager = LocalManager([local])
-application = local('application')
+application = local("application")
-url_map = Map([Rule('/static/<file>', endpoint='static', build_only=True)])
+url_map = Map([Rule("/static/<file>", endpoint="static", build_only=True)])
jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH))
def expose(rule, **kw):
def decorate(f):
- kw['endpoint'] = f.__name__
+ kw["endpoint"] = f.__name__
url_map.add(Rule(rule, **kw))
return f
+
return decorate
+
def url_for(endpoint, _external=False, **values):
return local.url_adapter.build(endpoint, values, force_external=_external)
-jinja_env.globals['url_for'] = url_for
+
+
+jinja_env.globals["url_for"] = url_for
+
def render_template(template, **context):
- return Response(jinja_env.get_template(template).render(**context),
- mimetype='text/html')
+ return Response(
+ jinja_env.get_template(template).render(**context), mimetype="text/html"
+ )
+
def validate_url(url):
return url_parse(url)[0] in ALLOWED_SCHEMES
+
def get_random_uid():
- return ''.join(sample(URL_CHARS, randrange(3, 9)))
+ return "".join(sample(URL_CHARS, randrange(3, 9)))
-class Pagination(object):
+class Pagination(object):
def __init__(self, results, per_page, page, endpoint):
self.results = results
self.per_page = per_page
@@ -56,7 +69,11 @@ class Pagination(object):
@cached_property
def entries(self):
- return self.results[((self.page - 1) * self.per_page):(((self.page - 1) * self.per_page)+self.per_page)]
+ return self.results[
+ ((self.page - 1) * self.per_page) : (
+ ((self.page - 1) * self.per_page) + self.per_page
+ )
+ ]
has_previous = property(lambda self: self.page > 1)
has_next = property(lambda self: self.page < self.pages)
diff --git a/examples/couchy/views.py b/examples/couchy/views.py
index 39c8ea29..c1547e7d 100644
--- a/examples/couchy/views.py
+++ b/examples/couchy/views.py
@@ -1,61 +1,73 @@
-from werkzeug.utils import redirect
from werkzeug.exceptions import NotFound
-from couchy.utils import render_template, expose, \
- validate_url, url_for, Pagination
-from couchy.models import URL
+from werkzeug.utils import redirect
+from .models import URL
+from .utils import expose
+from .utils import Pagination
+from .utils import render_template
+from .utils import url_for
+from .utils import validate_url
-@expose('/')
+
+@expose("/")
def new(request):
- error = url = ''
- if request.method == 'POST':
- url = request.form.get('url')
- alias = request.form.get('alias')
+ error = url = ""
+ if request.method == "POST":
+ url = request.form.get("url")
+ alias = request.form.get("alias")
if not validate_url(url):
error = "I'm sorry but you cannot shorten this URL."
elif alias:
if len(alias) > 140:
- error = 'Your alias is too long'
- elif '/' in alias:
- error = 'Your alias might not include a slash'
+ error = "Your alias is too long"
+ elif "/" in alias:
+ error = "Your alias might not include a slash"
elif URL.load(alias):
- error = 'The alias you have requested exists already'
+ error = "The alias you have requested exists already"
if not error:
- url = URL(target=url, public='private' not in request.form, shorty_id=alias if alias else None)
+ url = URL(
+ target=url,
+ public="private" not in request.form,
+ shorty_id=alias if alias else None,
+ )
url.store()
uid = url.id
- return redirect(url_for('display', uid=uid))
- return render_template('new.html', error=error, url=url)
+ return redirect(url_for("display", uid=uid))
+ return render_template("new.html", error=error, url=url)
+
-@expose('/display/<uid>')
+@expose("/display/<uid>")
def display(request, uid):
url = URL.load(uid)
if not url:
raise NotFound()
- return render_template('display.html', url=url)
+ return render_template("display.html", url=url)
-@expose('/u/<uid>')
+
+@expose("/u/<uid>")
def link(request, uid):
url = URL.load(uid)
if not url:
raise NotFound()
return redirect(url.target, 301)
-@expose('/list/', defaults={'page': 1})
-@expose('/list/<int:page>')
+
+@expose("/list/", defaults={"page": 1})
+@expose("/list/<int:page>")
def list(request, page):
def wrap(doc):
data = doc.value
- data['_id'] = doc.id
+ data["_id"] = doc.id
return URL.wrap(data)
- code = '''function(doc) { if (doc.public){ map([doc._id], doc); }}'''
+ code = """function(doc) { if (doc.public){ map([doc._id], doc); }}"""
docResults = URL.query(code)
results = [wrap(doc) for doc in docResults]
- pagination = Pagination(results, 1, page, 'list')
+ pagination = Pagination(results, 1, page, "list")
if pagination.page > 1 and not pagination.entries:
raise NotFound()
- return render_template('list.html', pagination=pagination)
+ return render_template("list.html", pagination=pagination)
+
def not_found(request):
- return render_template('not_found.html')
+ return render_template("not_found.html")
diff --git a/examples/cupoftee/__init__.py b/examples/cupoftee/__init__.py
index f4a9d04d..184c5d0d 100644
--- a/examples/cupoftee/__init__.py
+++ b/examples/cupoftee/__init__.py
@@ -8,4 +8,4 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from cupoftee.application import make_app
+from .application import make_app
diff --git a/examples/cupoftee/application.py b/examples/cupoftee/application.py
index 140b6b74..540e3f59 100644
--- a/examples/cupoftee/application.py
+++ b/examples/cupoftee/application.py
@@ -9,44 +9,58 @@
:license: BSD-3-Clause
"""
import time
-
-from jinja2 import Environment, PackageLoader
from os import path
from threading import Thread
-from cupoftee.db import Database
-from cupoftee.network import ServerBrowser
+from jinja2 import Environment
+from jinja2 import PackageLoader
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.middleware.shared_data import SharedDataMiddleware
-from werkzeug.wrappers import Request, Response
-from werkzeug.exceptions import HTTPException, NotFound
-from werkzeug.routing import Map, Rule
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
+
+from .db import Database
+from .network import ServerBrowser
-templates = path.join(path.dirname(__file__), 'templates')
+templates = path.join(path.dirname(__file__), "templates")
pages = {}
-url_map = Map([Rule('/shared/<file>', endpoint='shared')])
+url_map = Map([Rule("/shared/<file>", endpoint="shared")])
def make_app(database, interval=120):
- return SharedDataMiddleware(Cup(database, interval), {
- '/shared': path.join(path.dirname(__file__), 'shared')
- })
+ return SharedDataMiddleware(
+ Cup(database, interval),
+ {"/shared": path.join(path.dirname(__file__), "shared")},
+ )
class PageMeta(type):
-
def __init__(cls, name, bases, d):
type.__init__(cls, name, bases, d)
- if d.get('url_rule') is not None:
+ if d.get("url_rule") is not None:
pages[cls.identifier] = cls
- url_map.add(Rule(cls.url_rule, endpoint=cls.identifier,
- **cls.url_arguments))
+ url_map.add(
+ Rule(cls.url_rule, endpoint=cls.identifier, **cls.url_arguments)
+ )
identifier = property(lambda self: self.__name__.lower())
-class Page(object):
- __metaclass__ = PageMeta
+def _with_metaclass(meta, *bases):
+ """Create a base class with a metaclass."""
+
+ class metaclass(type):
+ def __new__(metacls, name, this_bases, d):
+ return meta(name, bases, d)
+
+ return type.__new__(metaclass, "temporary_class", (), {})
+
+
+class Page(_with_metaclass(PageMeta, object)):
url_arguments = {}
def __init__(self, cup, request, url_adapter):
@@ -62,17 +76,16 @@ class Page(object):
def render_template(self, template=None):
if template is None:
- template = self.__class__.identifier + '.html'
+ template = self.__class__.identifier + ".html"
context = dict(self.__dict__)
context.update(url_for=self.url_for, self=self)
return self.cup.render_template(template, context)
def get_response(self):
- return Response(self.render_template(), mimetype='text/html')
+ return Response(self.render_template(), mimetype="text/html")
class Cup(object):
-
def __init__(self, database, interval=120):
self.jinja_env = Environment(loader=PackageLoader("cupoftee"), autoescape=True)
self.interval = interval
@@ -111,4 +124,5 @@ class Cup(object):
template = self.jinja_env.get_template(name)
return template.render(context)
+
from cupoftee.pages import MissingPage
diff --git a/examples/cupoftee/db.py b/examples/cupoftee/db.py
index 9db97c3d..7f041220 100644
--- a/examples/cupoftee/db.py
+++ b/examples/cupoftee/db.py
@@ -9,8 +9,9 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+from pickle import dumps
+from pickle import loads
from threading import Lock
-from pickle import dumps, loads
try:
import dbm
@@ -19,10 +20,9 @@ except ImportError:
class Database(object):
-
def __init__(self, filename):
self.filename = filename
- self._fs = dbm.open(filename, 'cf')
+ self._fs = dbm.open(filename, "cf")
self._local = {}
self._lock = Lock()
@@ -43,7 +43,7 @@ class Database(object):
def __delitem__(self, key, value):
with self._lock:
self._local.pop(key, None)
- if self._fs.has_key(key):
+ if key in self._fs:
del self._fs[key]
def __del__(self):
@@ -75,5 +75,5 @@ class Database(object):
try:
self.sync()
self._fs.close()
- except:
+ except Exception:
pass
diff --git a/examples/cupoftee/network.py b/examples/cupoftee/network.py
index 0e472c1b..74c775aa 100644
--- a/examples/cupoftee/network.py
+++ b/examples/cupoftee/network.py
@@ -9,9 +9,10 @@
:license: BSD-3-Clause
"""
import socket
-from math import log
from datetime import datetime
-from cupoftee.utils import unicodecmp
+from math import log
+
+from .utils import unicodecmp
class ServerError(Exception):
@@ -31,15 +32,14 @@ class Syncable(object):
class ServerBrowser(Syncable):
-
def __init__(self, cup):
self.cup = cup
- self.servers = cup.db.setdefault('servers', dict)
+ self.servers = cup.db.setdefault("servers", dict)
def _sync(self):
to_delete = set(self.servers)
for x in range(1, 17):
- addr = ('master%d.teeworlds.com' % x, 8300)
+ addr = ("master%d.teeworlds.com" % x, 8300)
print(addr)
try:
self._sync_master(addr, to_delete)
@@ -48,20 +48,22 @@ class ServerBrowser(Syncable):
for server_id in to_delete:
self.servers.pop(server_id, None)
if not self.servers:
- raise IOError('no servers found')
+ raise IOError("no servers found")
self.cup.db.sync()
def _sync_master(self, addr, to_delete):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(5)
- s.sendto(b'\x20\x00\x00\x00\x00\x48\xff\xff\xff\xffreqt', addr)
+ s.sendto(b"\x20\x00\x00\x00\x00\x48\xff\xff\xff\xffreqt", addr)
data = s.recvfrom(1024)[0][14:]
s.close()
for n in range(0, len(data) // 6):
- addr = ('.'.join(map(str, map(ord, data[n * 6:n * 6 + 4]))),
- ord(data[n * 6 + 5]) * 256 + ord(data[n * 6 + 4]))
- server_id = '%s:%d' % addr
+ addr = (
+ ".".join(map(str, map(ord, data[n * 6 : n * 6 + 4]))),
+ ord(data[n * 6 + 5]) * 256 + ord(data[n * 6 + 4]),
+ )
+ server_id = "%s:%d" % addr
if server_id in self.servers:
if not self.servers[server_id].sync():
continue
@@ -74,31 +76,31 @@ class ServerBrowser(Syncable):
class Server(Syncable):
-
def __init__(self, addr, server_id):
self.addr = addr
self.id = server_id
self.players = []
if not self.sync():
- raise ServerError('server not responding in time')
+ raise ServerError("server not responding in time")
def _sync(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(1)
- s.sendto(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xffgief', self.addr)
- bits = s.recvfrom(1024)[0][14:].split(b'\x00')
+ s.sendto(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xffgief", self.addr)
+ bits = s.recvfrom(1024)[0][14:].split(b"\x00")
s.close()
self.version, server_name, map_name = bits[:3]
- self.name = server_name.decode('latin1')
- self.map = map_name.decode('latin1')
+ self.name = server_name.decode("latin1")
+ self.map = map_name.decode("latin1")
self.gametype = bits[3]
- self.flags, self.progression, player_count, \
- self.max_players = map(int, bits[4:8])
+ self.flags, self.progression, player_count, self.max_players = map(
+ int, bits[4:8]
+ )
# sync the player stats
players = dict((p.name, p) for p in self.players)
for i in range(player_count):
- name = bits[8 + i * 2].decode('latin1')
+ name = bits[8 + i * 2].decode("latin1")
score = int(bits[9 + i * 2])
# update existing player
@@ -112,7 +114,7 @@ class Server(Syncable):
for player in players.values():
try:
self.players.remove(player)
- except:
+ except Exception:
pass
# sort the player list and count them
@@ -124,7 +126,6 @@ class Server(Syncable):
class Player(object):
-
def __init__(self, server, name, score):
self.server = server
self.name = name
diff --git a/examples/cupoftee/pages.py b/examples/cupoftee/pages.py
index 7d799c25..c1a823b7 100644
--- a/examples/cupoftee/pages.py
+++ b/examples/cupoftee/pages.py
@@ -10,41 +10,42 @@
"""
from functools import reduce
-from werkzeug.utils import redirect
from werkzeug.exceptions import NotFound
-from cupoftee.application import Page
-from cupoftee.utils import unicodecmp
+from werkzeug.utils import redirect
+
+from .application import Page
+from .utils import unicodecmp
class ServerList(Page):
- url_rule = '/'
+ url_rule = "/"
def order_link(self, name, title):
- cls = ''
- link = '?order_by=' + name
+ cls = ""
+ link = "?order_by=" + name
desc = False
if name == self.order_by:
desc = not self.order_desc
- cls = ' class="%s"' % ('down' if desc else 'up')
+ cls = ' class="%s"' % ("down" if desc else "up")
if desc:
- link += '&amp;dir=desc'
+ link += "&amp;dir=desc"
return '<a href="%s"%s>%s</a>' % (link, cls, title)
def process(self):
- self.order_by = self.request.args.get('order_by') or 'name'
+ self.order_by = self.request.args.get("order_by") or "name"
sort_func = {
- 'name': lambda x: x,
- 'map': lambda x: x.map,
- 'gametype': lambda x: x.gametype,
- 'players': lambda x: x.player_count,
- 'progression': lambda x: x.progression,
+ "name": lambda x: x,
+ "map": lambda x: x.map,
+ "gametype": lambda x: x.gametype,
+ "players": lambda x: x.player_count,
+ "progression": lambda x: x.progression,
}.get(self.order_by)
if sort_func is None:
- return redirect(self.url_for('serverlist'))
+ return redirect(self.url_for("serverlist"))
self.servers = self.cup.master.servers.values()
self.servers.sort(key=sort_func)
- if self.request.args.get('dir') == 'desc':
+ if self.request.args.get("dir") == "desc":
self.servers.reverse()
self.order_desc = True
else:
@@ -55,7 +56,7 @@ class ServerList(Page):
class Server(Page):
- url_rule = '/server/<id>'
+ url_rule = "/server/<id>"
def process(self, id):
try:
@@ -65,10 +66,10 @@ class Server(Page):
class Search(Page):
- url_rule = '/search'
+ url_rule = "/search"
def process(self):
- self.user = self.request.args.get('user')
+ self.user = self.request.args.get("user")
if self.user:
self.results = []
for server in self.cup.master.servers.values():
@@ -78,7 +79,6 @@ class Search(Page):
class MissingPage(Page):
-
def get_response(self):
response = super(MissingPage, self).get_response()
response.status_code = 404
diff --git a/examples/cupoftee/utils.py b/examples/cupoftee/utils.py
index f13e2f3d..91717e85 100644
--- a/examples/cupoftee/utils.py
+++ b/examples/cupoftee/utils.py
@@ -11,7 +11,7 @@
import re
-_sort_re = re.compile(r'\w+', re.UNICODE)
+_sort_re = re.compile(r"\w+", re.UNICODE)
def unicodecmp(a, b):
diff --git a/examples/httpbasicauth.py b/examples/httpbasicauth.py
index 5a774f3f..21d39300 100644
--- a/examples/httpbasicauth.py
+++ b/examples/httpbasicauth.py
@@ -10,12 +10,12 @@
:license: BSD-3-Clause
"""
from werkzeug.serving import run_simple
-from werkzeug.wrappers import Request, Response
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
class Application(object):
-
- def __init__(self, users, realm='login required'):
+ def __init__(self, users, realm="login required"):
self.users = users
self.realm = realm
@@ -23,12 +23,15 @@ class Application(object):
return username in self.users and self.users[username] == password
def auth_required(self, request):
- return Response('Could not verify your access level for that URL.\n'
- 'You have to login with proper credentials', 401,
- {'WWW-Authenticate': 'Basic realm="%s"' % self.realm})
+ return Response(
+ "Could not verify your access level for that URL.\n"
+ "You have to login with proper credentials",
+ 401,
+ {"WWW-Authenticate": 'Basic realm="%s"' % self.realm},
+ )
def dispatch_request(self, request):
- return Response('Logged in as %s' % request.authorization.username)
+ return Response("Logged in as %s" % request.authorization.username)
def __call__(self, environ, start_response):
request = Request(environ)
@@ -40,6 +43,6 @@ class Application(object):
return response(environ, start_response)
-if __name__ == '__main__':
- application = Application({'user1': 'password', 'user2': 'password'})
- run_simple('localhost', 5000, application)
+if __name__ == "__main__":
+ application = Application({"user1": "password", "user2": "password"})
+ run_simple("localhost", 5000, application)
diff --git a/examples/i18nurls/__init__.py b/examples/i18nurls/__init__.py
index 393d270c..f5f5c6ed 100644
--- a/examples/i18nurls/__init__.py
+++ b/examples/i18nurls/__init__.py
@@ -1 +1 @@
-from i18nurls.application import Application as make_app
+from .application import Application as make_app
diff --git a/examples/i18nurls/application.py b/examples/i18nurls/application.py
index 54f12353..103f6001 100644
--- a/examples/i18nurls/application.py
+++ b/examples/i18nurls/application.py
@@ -1,33 +1,39 @@
-from jinja2 import Environment, PackageLoader
from os import path
-from werkzeug.wrappers import Request as _Request, BaseResponse
+
+from jinja2 import Environment
+from jinja2 import PackageLoader
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.routing import RequestRedirect
-from werkzeug.exceptions import HTTPException, NotFound
-from i18nurls.urls import map
+from werkzeug.wrappers import BaseResponse
+from werkzeug.wrappers import Request as _Request
+
+from .urls import map
-TEMPLATES = path.join(path.dirname(__file__), 'templates')
+TEMPLATES = path.join(path.dirname(__file__), "templates")
views = {}
def expose(name):
"""Register the function as view."""
+
def wrapped(f):
views[name] = f
return f
+
return wrapped
class Request(_Request):
-
def __init__(self, environ, urls):
super(Request, self).__init__(environ)
self.urls = urls
self.matched_url = None
def url_for(self, endpoint, **args):
- if not 'lang_code' in args:
- args['lang_code'] = self.language
- if endpoint == 'this':
+ if "lang_code" not in args:
+ args["lang_code"] = self.language
+ if endpoint == "this":
endpoint = self.matched_url[0]
tmp = self.matched_url[1].copy()
tmp.update(args)
@@ -45,12 +51,12 @@ class TemplateResponse(Response):
def __init__(self, template_name, **values):
self.template_name = template_name
self.template_values = values
- Response.__init__(self, mimetype='text/html')
+ Response.__init__(self, mimetype="text/html")
def __call__(self, environ, start_response):
- req = environ['werkzeug.request']
+ req = environ["werkzeug.request"]
values = self.template_values.copy()
- values['req'] = req
+ values["req"] = req
self.data = self.render_template(self.template_name, values)
return super(TemplateResponse, self).__call__(environ, start_response)
@@ -60,9 +66,9 @@ class TemplateResponse(Response):
class Application(object):
-
def __init__(self):
from i18nurls import views
+
self.not_found = views.page_not_found
def __call__(self, environ, start_response):
@@ -71,14 +77,14 @@ class Application(object):
try:
endpoint, args = urls.match(req.path)
req.matched_url = (endpoint, args)
- if endpoint == '#language_select':
+ if endpoint == "#language_select":
lng = req.accept_languages.best
- lng = lng and lng.split('-')[0].lower() or 'en'
- index_url = urls.build('index', {'lang_code': lng})
- resp = Response('Moved to %s' % index_url, status=302)
- resp.headers['Location'] = index_url
+ lng = lng and lng.split("-")[0].lower() or "en"
+ index_url = urls.build("index", {"lang_code": lng})
+ resp = Response("Moved to %s" % index_url, status=302)
+ resp.headers["Location"] = index_url
else:
- req.language = args.pop('lang_code', None)
+ req.language = args.pop("lang_code", None)
resp = views[endpoint](req, **args)
except NotFound:
resp = self.not_found(req)
diff --git a/examples/i18nurls/urls.py b/examples/i18nurls/urls.py
index 57ff68a2..3dd54a00 100644
--- a/examples/i18nurls/urls.py
+++ b/examples/i18nurls/urls.py
@@ -1,11 +1,18 @@
-from werkzeug.routing import Map, Rule, Submount
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
+from werkzeug.routing import Submount
-map = Map([
- Rule('/', endpoint='#language_select'),
- Submount('/<string(length=2):lang_code>', [
- Rule('/', endpoint='index'),
- Rule('/about', endpoint='about'),
- Rule('/blog/', endpoint='blog/index'),
- Rule('/blog/<int:post_id>', endpoint='blog/show')
- ])
-])
+map = Map(
+ [
+ Rule("/", endpoint="#language_select"),
+ Submount(
+ "/<string(length=2):lang_code>",
+ [
+ Rule("/", endpoint="index"),
+ Rule("/about", endpoint="about"),
+ Rule("/blog/", endpoint="blog/index"),
+ Rule("/blog/<int:post_id>", endpoint="blog/show"),
+ ],
+ ),
+ ]
+)
diff --git a/examples/i18nurls/views.py b/examples/i18nurls/views.py
index 7b8bdb70..6c0ca895 100644
--- a/examples/i18nurls/views.py
+++ b/examples/i18nurls/views.py
@@ -1,22 +1,29 @@
-from i18nurls.application import TemplateResponse, Response, expose
+from .application import expose
+from .application import Response
+from .application import TemplateResponse
-@expose('index')
+@expose("index")
def index(req):
- return TemplateResponse('index.html', title='Index')
+ return TemplateResponse("index.html", title="Index")
-@expose('about')
+
+@expose("about")
def about(req):
- return TemplateResponse('about.html', title='About')
+ return TemplateResponse("about.html", title="About")
+
-@expose('blog/index')
+@expose("blog/index")
def blog_index(req):
- return TemplateResponse('blog.html', title='Blog Index', mode='index')
+ return TemplateResponse("blog.html", title="Blog Index", mode="index")
-@expose('blog/show')
+
+@expose("blog/show")
def blog_show(req, post_id):
- return TemplateResponse('blog.html', title='Blog Post #%d' % post_id,
- post_id=post_id, mode='show')
+ return TemplateResponse(
+ "blog.html", title="Blog Post #%d" % post_id, post_id=post_id, mode="show"
+ )
+
def page_not_found(req):
- return Response('<h1>Page Not Found</h1>', mimetype='text/html')
+ return Response("<h1>Page Not Found</h1>", mimetype="text/html")
diff --git a/examples/manage-coolmagic.py b/examples/manage-coolmagic.py
index 8059003f..f6abe80e 100755
--- a/examples/manage-coolmagic.py
+++ b/examples/manage-coolmagic.py
@@ -10,9 +10,10 @@
:license: BSD-3-Clause
"""
import click
-from coolmagic import make_app
from werkzeug.serving import run_simple
+from coolmagic import make_app
+
@click.group()
def cli():
@@ -20,36 +21,45 @@ def cli():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = dict()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -57,7 +67,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-couchy.py b/examples/manage-couchy.py
index 1bbf4a67..8a00a400 100755
--- a/examples/manage-couchy.py
+++ b/examples/manage-couchy.py
@@ -5,17 +5,15 @@ from werkzeug.serving import run_simple
def make_app():
from couchy.application import Couchy
- return Couchy('http://localhost:5984')
+
+ return Couchy("http://localhost:5984")
def make_shell():
from couchy import models, utils
+
application = make_app()
- return {
- "application": application,
- "models": models,
- "utils": utils,
- }
+ return {"application": application, "models": models, "utils": utils}
@click.group()
@@ -26,40 +24,50 @@ def cli():
@cli.command()
def initdb():
from couchy.application import Couchy
- Couchy('http://localhost:5984').init_database()
+
+ Couchy("http://localhost:5984").init_database()
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = make_shell()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -67,7 +75,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-cupoftee.py b/examples/manage-cupoftee.py
index e905f06b..a787f324 100755
--- a/examples/manage-cupoftee.py
+++ b/examples/manage-cupoftee.py
@@ -14,7 +14,8 @@ from werkzeug.serving import run_simple
def make_app():
from cupoftee import make_app
- return make_app('/tmp/cupoftee.db')
+
+ return make_app("/tmp/cupoftee.db")
@click.group()
@@ -23,20 +24,27 @@ def cli():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, reloader, debugger, evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
-
-
-if __name__ == '__main__':
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
+
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-i18nurls.py b/examples/manage-i18nurls.py
index 04a5068e..a2b746a2 100755
--- a/examples/manage-i18nurls.py
+++ b/examples/manage-i18nurls.py
@@ -10,9 +10,10 @@
:license: BSD-3-Clause
"""
import click
-from i18nurls import make_app
from werkzeug.serving import run_simple
+from i18nurls import make_app
+
@click.group()
def cli():
@@ -20,36 +21,45 @@ def cli():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = dict()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -57,7 +67,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-plnt.py b/examples/manage-plnt.py
index 8ce5b229..80a03f78 100755
--- a/examples/manage-plnt.py
+++ b/examples/manage-plnt.py
@@ -9,16 +9,18 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import click
import os
+
+import click
from werkzeug.serving import run_simple
def make_app():
"""Helper function that creates a plnt app."""
from plnt import Plnt
- database_uri = os.environ.get('PLNT_DATABASE_URI')
- app = Plnt(database_uri or 'sqlite:////tmp/plnt.db')
+
+ database_uri = os.environ.get("PLNT_DATABASE_URI")
+ app = Plnt(database_uri or "sqlite:////tmp/plnt.db")
app.bind_to_context()
return app
@@ -32,62 +34,88 @@ def cli():
def initdb():
"""Initialize the database"""
from plnt.database import Blog, session
+
make_app().init_database()
# and now fill in some python blogs everybody should read (shamelessly
# added my own blog too)
blogs = [
- Blog('Armin Ronacher', 'http://lucumr.pocoo.org/',
- 'http://lucumr.pocoo.org/cogitations/feed/'),
- Blog('Georg Brandl', 'http://pyside.blogspot.com/',
- 'http://pyside.blogspot.com/feeds/posts/default'),
- Blog('Ian Bicking', 'http://blog.ianbicking.org/',
- 'http://blog.ianbicking.org/feed/'),
- Blog('Amir Salihefendic', 'http://amix.dk/',
- 'http://feeds.feedburner.com/amixdk'),
- Blog('Christopher Lenz', 'http://www.cmlenz.net/blog/',
- 'http://www.cmlenz.net/blog/atom.xml'),
- Blog('Frederick Lundh', 'http://online.effbot.org/',
- 'http://online.effbot.org/rss.xml')
+ Blog(
+ "Armin Ronacher",
+ "http://lucumr.pocoo.org/",
+ "http://lucumr.pocoo.org/cogitations/feed/",
+ ),
+ Blog(
+ "Georg Brandl",
+ "http://pyside.blogspot.com/",
+ "http://pyside.blogspot.com/feeds/posts/default",
+ ),
+ Blog(
+ "Ian Bicking",
+ "http://blog.ianbicking.org/",
+ "http://blog.ianbicking.org/feed/",
+ ),
+ Blog(
+ "Amir Salihefendic", "http://amix.dk/", "http://feeds.feedburner.com/amixdk"
+ ),
+ Blog(
+ "Christopher Lenz",
+ "http://www.cmlenz.net/blog/",
+ "http://www.cmlenz.net/blog/atom.xml",
+ ),
+ Blog(
+ "Frederick Lundh",
+ "http://online.effbot.org/",
+ "http://online.effbot.org/rss.xml",
+ ),
]
# okay. got tired here. if someone feels that he is missing, drop me
# a line ;-)
for blog in blogs:
session.add(blog)
session.commit()
- click.echo('Initialized database, now run manage-plnt.py sync to get the posts')
+ click.echo("Initialized database, now run manage-plnt.py sync to get the posts")
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
- namespace = {'app': make_app()}
+ banner = "Interactive Werkzeug Shell"
+ namespace = {"app": make_app()}
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -95,6 +123,7 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
@@ -102,8 +131,10 @@ def shell(no_ipython):
def sync():
"""Sync the blogs in the planet. Call this from a cronjob."""
from plnt.sync import sync
+
make_app().bind_to_context()
sync()
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-shorty.py b/examples/manage-shorty.py
index d80eccab..80fb3731 100755
--- a/examples/manage-shorty.py
+++ b/examples/manage-shorty.py
@@ -1,24 +1,23 @@
#!/usr/bin/env python
-import click
import os
import tempfile
+
+import click
from werkzeug.serving import run_simple
def make_app():
from shorty.application import Shorty
+
filename = os.path.join(tempfile.gettempdir(), "shorty.db")
- return Shorty('sqlite:///{0}'.format(filename))
+ return Shorty("sqlite:///{0}".format(filename))
def make_shell():
from shorty import models, utils
+
application = make_app()
- return {
- "application": application,
- "models": models,
- "utils": utils,
- }
+ return {"application": application, "models": models, "utils": utils}
@click.group()
@@ -32,36 +31,45 @@ def initdb():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_app()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = make_shell()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -69,7 +77,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-simplewiki.py b/examples/manage-simplewiki.py
index 3e23ceb1..5fd66b85 100755
--- a/examples/manage-simplewiki.py
+++ b/examples/manage-simplewiki.py
@@ -9,26 +9,26 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import click
import os
+
+import click
from werkzeug.serving import run_simple
def make_wiki():
"""Helper function that creates a new wiki instance."""
from simplewiki import SimpleWiki
- database_uri = os.environ.get('SIMPLEWIKI_DATABASE_URI')
- return SimpleWiki(database_uri or 'sqlite:////tmp/simplewiki.db')
+
+ database_uri = os.environ.get("SIMPLEWIKI_DATABASE_URI")
+ return SimpleWiki(database_uri or "sqlite:////tmp/simplewiki.db")
def make_shell():
from simplewiki import database
+
wiki = make_wiki()
wiki.bind_to_context()
- return {
- 'wiki': wiki,
- 'db': database
- }
+ return {"wiki": wiki, "db": database}
@click.group()
@@ -42,36 +42,45 @@ def initdb():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
app = make_wiki()
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = make_shell()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -79,7 +88,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/manage-webpylike.py b/examples/manage-webpylike.py
index 2e77714d..010bdcee 100755
--- a/examples/manage-webpylike.py
+++ b/examples/manage-webpylike.py
@@ -13,13 +13,16 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import click
import os
import sys
-sys.path.append(os.path.join(os.path.dirname(__file__), 'webpylike'))
-from webpylike.example import app
+
+import click
from werkzeug.serving import run_simple
+from webpylike.example import app
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "webpylike"))
+
@click.group()
def cli():
@@ -27,35 +30,44 @@ def cli():
@cli.command()
-@click.option('-h', '--hostname', type=str, default='localhost', help="localhost")
-@click.option('-p', '--port', type=int, default=5000, help="5000")
-@click.option('--no-reloader', is_flag=True, default=False)
-@click.option('--debugger', is_flag=True)
-@click.option('--no-evalex', is_flag=True, default=False)
-@click.option('--threaded', is_flag=True)
-@click.option('--processes', type=int, default=1, help="1")
+@click.option("-h", "--hostname", type=str, default="localhost", help="localhost")
+@click.option("-p", "--port", type=int, default=5000, help="5000")
+@click.option("--no-reloader", is_flag=True, default=False)
+@click.option("--debugger", is_flag=True)
+@click.option("--no-evalex", is_flag=True, default=False)
+@click.option("--threaded", is_flag=True)
+@click.option("--processes", type=int, default=1, help="1")
def runserver(hostname, port, no_reloader, debugger, no_evalex, threaded, processes):
"""Start a new development server."""
reloader = not no_reloader
evalex = not no_evalex
- run_simple(hostname, port, app,
- use_reloader=reloader, use_debugger=debugger,
- use_evalex=evalex, threaded=threaded, processes=processes)
+ run_simple(
+ hostname,
+ port,
+ app,
+ use_reloader=reloader,
+ use_debugger=debugger,
+ use_evalex=evalex,
+ threaded=threaded,
+ processes=processes,
+ )
@cli.command()
-@click.option('--no-ipython', is_flag=True, default=False)
+@click.option("--no-ipython", is_flag=True, default=False)
def shell(no_ipython):
"""Start a new interactive python session."""
- banner = 'Interactive Werkzeug Shell'
+ banner = "Interactive Werkzeug Shell"
namespace = dict()
if not no_ipython:
try:
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
+
sh = InteractiveShellEmbed.instance(banner1=banner)
except ImportError:
from IPython.Shell import IPShellEmbed
+
sh = IPShellEmbed(banner=banner)
except ImportError:
pass
@@ -63,7 +75,9 @@ def shell(no_ipython):
sh(local_ns=namespace)
return
from code import interact
+
interact(banner, local=namespace)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
cli()
diff --git a/examples/partial/complex_routing.py b/examples/partial/complex_routing.py
index 18ce7b0a..596d00e4 100644
--- a/examples/partial/complex_routing.py
+++ b/examples/partial/complex_routing.py
@@ -1,19 +1,43 @@
-from werkzeug.routing import Map, Rule, Subdomain, Submount, EndpointPrefix
+from werkzeug.routing import EndpointPrefix
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
+from werkzeug.routing import Subdomain
+from werkzeug.routing import Submount
-m = Map([
- # Static URLs
- EndpointPrefix('static/', [
- Rule('/', endpoint='index'),
- Rule('/about', endpoint='about'),
- Rule('/help', endpoint='help'),
- ]),
- # Knowledge Base
- Subdomain('kb', [EndpointPrefix('kb/', [
- Rule('/', endpoint='index'),
- Submount('/browse', [
- Rule('/', endpoint='browse'),
- Rule('/<int:id>/', defaults={'page': 1}, endpoint='browse'),
- Rule('/<int:id>/<int:page>', endpoint='browse')
- ])
- ])])
-])
+m = Map(
+ [
+ # Static URLs
+ EndpointPrefix(
+ "static/",
+ [
+ Rule("/", endpoint="index"),
+ Rule("/about", endpoint="about"),
+ Rule("/help", endpoint="help"),
+ ],
+ ),
+ # Knowledge Base
+ Subdomain(
+ "kb",
+ [
+ EndpointPrefix(
+ "kb/",
+ [
+ Rule("/", endpoint="index"),
+ Submount(
+ "/browse",
+ [
+ Rule("/", endpoint="browse"),
+ Rule(
+ "/<int:id>/",
+ defaults={"page": 1},
+ endpoint="browse",
+ ),
+ Rule("/<int:id>/<int:page>", endpoint="browse"),
+ ],
+ ),
+ ],
+ )
+ ],
+ ),
+ ]
+)
diff --git a/examples/plnt/__init__.py b/examples/plnt/__init__.py
index f1bb94f7..7ade4b72 100644
--- a/examples/plnt/__init__.py
+++ b/examples/plnt/__init__.py
@@ -8,4 +8,4 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from plnt.webapp import Plnt
+from .webapp import Plnt
diff --git a/examples/plnt/database.py b/examples/plnt/database.py
index c74abb84..78e4ab6a 100644
--- a/examples/plnt/database.py
+++ b/examples/plnt/database.py
@@ -8,62 +8,73 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from sqlalchemy import MetaData, Table, Column, ForeignKey, \
- Integer, String, DateTime
-from sqlalchemy.orm import dynamic_loader, scoped_session, create_session, \
- mapper
-from plnt.utils import application, local_manager
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import dynamic_loader
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import scoped_session
+
+from .utils import application
+from .utils import local_manager
def new_db_session():
- return create_session(application.database_engine, autoflush=True,
- autocommit=False)
+ return create_session(application.database_engine, autoflush=True, autocommit=False)
+
metadata = MetaData()
session = scoped_session(new_db_session, local_manager.get_ident)
-blog_table = Table('blogs', metadata,
- Column('id', Integer, primary_key=True),
- Column('name', String(120)),
- Column('description', String),
- Column('url', String(200)),
- Column('feed_url', String(250))
+blog_table = Table(
+ "blogs",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(120)),
+ Column("description", String),
+ Column("url", String(200)),
+ Column("feed_url", String(250)),
)
-entry_table = Table('entries', metadata,
- Column('id', Integer, primary_key=True),
- Column('blog_id', Integer, ForeignKey('blogs.id')),
- Column('guid', String(200), unique=True),
- Column('title', String(140)),
- Column('url', String(200)),
- Column('text', String),
- Column('pub_date', DateTime),
- Column('last_update', DateTime)
+entry_table = Table(
+ "entries",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("blog_id", Integer, ForeignKey("blogs.id")),
+ Column("guid", String(200), unique=True),
+ Column("title", String(140)),
+ Column("url", String(200)),
+ Column("text", String),
+ Column("pub_date", DateTime),
+ Column("last_update", DateTime),
)
class Blog(object):
query = session.query_property()
- def __init__(self, name, url, feed_url, description=u''):
+ def __init__(self, name, url, feed_url, description=u""):
self.name = name
self.url = url
self.feed_url = feed_url
self.description = description
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, self.url)
+ return "<%s %r>" % (self.__class__.__name__, self.url)
class Entry(object):
query = session.query_property()
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, self.guid)
+ return "<%s %r>" % (self.__class__.__name__, self.guid)
mapper(Entry, entry_table)
-mapper(Blog, blog_table, properties=dict(
- entries=dynamic_loader(Entry, backref='blog')
-))
+mapper(Blog, blog_table, properties=dict(entries=dynamic_loader(Entry, backref="blog")))
diff --git a/examples/plnt/sync.py b/examples/plnt/sync.py
index b0265d85..e12ab844 100644
--- a/examples/plnt/sync.py
+++ b/examples/plnt/sync.py
@@ -8,14 +8,19 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import feedparser
from datetime import datetime
+
+import feedparser
from werkzeug.utils import escape
-from plnt.database import Blog, Entry, session
-from plnt.utils import strip_tags, nl2p
+
+from .database import Blog
+from .database import Entry
+from .database import session
+from .utils import nl2p
+from .utils import strip_tags
-HTML_MIMETYPES = {'text/html', 'application/xhtml+xml'}
+HTML_MIMETYPES = {"text/html", "application/xhtml+xml"}
def sync():
@@ -31,7 +36,7 @@ def sync():
for entry in feed.entries:
# get the guid. either the id if specified, otherwise the link.
# if none is available we skip the entry.
- guid = entry.get('id') or entry.get('link')
+ guid = entry.get("id") or entry.get("link")
if not guid:
continue
@@ -41,17 +46,18 @@ def sync():
# get title, url and text. skip if no title or no text is
# given. if the link is missing we use the blog link.
- if 'title_detail' in entry:
- title = entry.title_detail.get('value') or ''
- if entry.title_detail.get('type') in HTML_MIMETYPES:
+ if "title_detail" in entry:
+ title = entry.title_detail.get("value") or ""
+ if entry.title_detail.get("type") in HTML_MIMETYPES:
title = strip_tags(title)
else:
title = escape(title)
else:
- title = entry.get('title')
- url = entry.get('link') or blog.blog_url
- text = 'content' in entry and entry.content[0] or \
- entry.get('summary_detail')
+ title = entry.get("title")
+ url = entry.get("link") or blog.blog_url
+ text = (
+ "content" in entry and entry.content[0] or entry.get("summary_detail")
+ )
if not title or not text:
continue
@@ -59,10 +65,10 @@ def sync():
# if we have an html text we use that, otherwise we HTML
# escape the text and use that one. We also handle XHTML
# with our tag soup parser for the moment.
- if text.get('type') not in HTML_MIMETYPES:
- text = escape(nl2p(text.get('value') or ''))
+ if text.get("type") not in HTML_MIMETYPES:
+ text = escape(nl2p(text.get("value") or ""))
else:
- text = text.get('value') or ''
+ text = text.get("value") or ""
# no text? continue
if not text.strip():
@@ -70,10 +76,12 @@ def sync():
# get the pub date and updated date. This is rather complex
# because different feeds do different stuff
- pub_date = entry.get('published_parsed') or \
- entry.get('created_parsed') or \
- entry.get('date_parsed')
- updated = entry.get('updated_parsed') or pub_date
+ pub_date = (
+ entry.get("published_parsed")
+ or entry.get("created_parsed")
+ or entry.get("date_parsed")
+ )
+ updated = entry.get("updated_parsed") or pub_date
pub_date = pub_date or updated
# if we don't have a pub_date we skip.
diff --git a/examples/plnt/utils.py b/examples/plnt/utils.py
index 957c28da..5c6f0d0b 100644
--- a/examples/plnt/utils.py
+++ b/examples/plnt/utils.py
@@ -10,13 +10,16 @@
"""
import re
from os import path
-from jinja2 import Environment, FileSystemLoader
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
from werkzeug._compat import unichr
-from werkzeug.local import Local, LocalManager
+from werkzeug.local import Local
+from werkzeug.local import LocalManager
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
from werkzeug.utils import cached_property
from werkzeug.wrappers import Response
-from werkzeug.routing import Map, Rule
# context locals. these two objects are use by the application to
@@ -29,24 +32,24 @@ local_manager = LocalManager([local])
# proxy objects
-request = local('request')
-application = local('application')
-url_adapter = local('url_adapter')
+request = local("request")
+application = local("application")
+url_adapter = local("url_adapter")
# let's use jinja for templates this time
-template_path = path.join(path.dirname(__file__), 'templates')
+template_path = path.join(path.dirname(__file__), "templates")
jinja_env = Environment(loader=FileSystemLoader(template_path))
# the collected url patterns
-url_map = Map([Rule('/shared/<path:file>', endpoint='shared')])
+url_map = Map([Rule("/shared/<path:file>", endpoint="shared")])
endpoints = {}
-_par_re = re.compile(r'\n{2,}')
-_entity_re = re.compile(r'&([^;]+);')
-_striptags_re = re.compile(r'(<!--.*-->|<[^>]*>)')
+_par_re = re.compile(r"\n{2,}")
+_entity_re = re.compile(r"&([^;]+);")
+_striptags_re = re.compile(r"(<!--.*-->|<[^>]*>)")
try:
from html.entities import name2codepoint
@@ -54,30 +57,32 @@ except ImportError:
from htmlentitydefs import name2codepoint
html_entities = name2codepoint.copy()
-html_entities['apos'] = 39
+html_entities["apos"] = 39
del name2codepoint
def expose(url_rule, endpoint=None, **kwargs):
"""Expose this function to the web layer."""
+
def decorate(f):
e = endpoint or f.__name__
endpoints[e] = f
url_map.add(Rule(url_rule, endpoint=e, **kwargs))
return f
+
return decorate
def render_template(template_name, **context):
"""Render a template into a response."""
tmpl = jinja_env.get_template(template_name)
- context['url_for'] = url_for
- return Response(tmpl.render(context), mimetype='text/html')
+ context["url_for"] = url_for
+ return Response(tmpl.render(context), mimetype="text/html")
def nl2p(s):
"""Add paragraphs to a text."""
- return u'\n'.join(u'<p>%s</p>' % p for p in _par_re.split(s))
+ return u"\n".join(u"<p>%s</p>" % p for p in _par_re.split(s))
def url_for(endpoint, **kw):
@@ -87,22 +92,24 @@ def url_for(endpoint, **kw):
def strip_tags(s):
"""Resolve HTML entities and remove tags from a string."""
+
def handle_match(m):
name = m.group(1)
if name in html_entities:
return unichr(html_entities[name])
- if name[:2] in ('#x', '#X'):
+ if name[:2] in ("#x", "#X"):
try:
return unichr(int(name[2:], 16))
except ValueError:
- return u''
- elif name.startswith('#'):
+ return u""
+ elif name.startswith("#"):
try:
return unichr(int(name[1:]))
except ValueError:
- return u''
- return u''
- return _entity_re.sub(handle_match, _striptags_re.sub('', s))
+ return u""
+ return u""
+
+ return _entity_re.sub(handle_match, _striptags_re.sub("", s))
class Pagination(object):
@@ -118,8 +125,11 @@ class Pagination(object):
@cached_property
def entries(self):
- return self.query.offset((self.page - 1) * self.per_page) \
- .limit(self.per_page).all()
+ return (
+ self.query.offset((self.page - 1) * self.per_page)
+ .limit(self.per_page)
+ .all()
+ )
@cached_property
def count(self):
diff --git a/examples/plnt/views.py b/examples/plnt/views.py
index 7397531c..d64e98e3 100644
--- a/examples/plnt/views.py
+++ b/examples/plnt/views.py
@@ -9,32 +9,35 @@
:license: BSD-3-Clause
"""
from datetime import date
-from plnt.database import Entry
-from plnt.utils import Pagination, expose, render_template
+
+from .database import Entry
+from .utils import expose
+from .utils import Pagination
+from .utils import render_template
#: number of items per page
PER_PAGE = 30
-@expose('/', defaults={'page': 1})
-@expose('/page/<int:page>')
+@expose("/", defaults={"page": 1})
+@expose("/page/<int:page>")
def index(request, page):
"""Show the index page or any an offset of it."""
days = []
days_found = set()
query = Entry.query.order_by(Entry.pub_date.desc())
- pagination = Pagination(query, PER_PAGE, page, 'index')
+ pagination = Pagination(query, PER_PAGE, page, "index")
for entry in pagination.entries:
day = date(*entry.pub_date.timetuple()[:3])
if day not in days_found:
days_found.add(day)
- days.append({'date': day, 'entries': []})
- days[-1]['entries'].append(entry)
- return render_template('index.html', days=days, pagination=pagination)
+ days.append({"date": day, "entries": []})
+ days[-1]["entries"].append(entry)
+ return render_template("index.html", days=days, pagination=pagination)
-@expose('/about')
+@expose("/about")
def about(request):
"""Show the about page, so that we have another view func ;-)"""
- return render_template('about.html')
+ return render_template("about.html")
diff --git a/examples/plnt/webapp.py b/examples/plnt/webapp.py
index b571841a..dee7336b 100644
--- a/examples/plnt/webapp.py
+++ b/examples/plnt/webapp.py
@@ -9,31 +9,31 @@
:license: BSD-3-Clause
"""
from os import path
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine
+from werkzeug.exceptions import HTTPException
from werkzeug.middleware.shared_data import SharedDataMiddleware
from werkzeug.wrappers import Request
from werkzeug.wsgi import ClosingIterator
-from werkzeug.exceptions import HTTPException
-from plnt.utils import local, local_manager, url_map, endpoints
-from plnt.database import session, metadata
-# import the views module because it contains setup code
-import plnt.views
+from . import views # noqa: F401
+from .database import metadata
+from .database import session
+from .utils import endpoints
+from .utils import local
+from .utils import local_manager
+from .utils import url_map
#: path to shared data
-SHARED_DATA = path.join(path.dirname(__file__), 'shared')
+SHARED_DATA = path.join(path.dirname(__file__), "shared")
class Plnt(object):
-
def __init__(self, database_uri):
self.database_engine = create_engine(database_uri)
self._dispatch = local_manager.middleware(self.dispatch_request)
- self._dispatch = SharedDataMiddleware(self._dispatch, {
- '/shared': SHARED_DATA
- })
+ self._dispatch = SharedDataMiddleware(self._dispatch, {"/shared": SHARED_DATA})
def init_database(self):
metadata.create_all(self.database_engine)
@@ -50,8 +50,7 @@ class Plnt(object):
response = endpoints[endpoint](request, **values)
except HTTPException as e:
response = e
- return ClosingIterator(response(environ, start_response),
- session.remove)
+ return ClosingIterator(response(environ, start_response), session.remove)
def __call__(self, environ, start_response):
return self._dispatch(environ, start_response)
diff --git a/examples/shortly/shortly.py b/examples/shortly/shortly.py
index 3446de0f..a3ff9f7f 100644
--- a/examples/shortly/shortly.py
+++ b/examples/shortly/shortly.py
@@ -9,32 +9,35 @@
:license: BSD-3-Clause
"""
import os
-import redis
+import redis
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.middleware.shared_data import SharedDataMiddleware
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
from werkzeug.urls import url_parse
-from werkzeug.wrappers import Request, Response
-from werkzeug.routing import Map, Rule
-from werkzeug.exceptions import HTTPException, NotFound
from werkzeug.utils import redirect
-
-from jinja2 import Environment, FileSystemLoader
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def base36_encode(number):
- assert number >= 0, 'positive integer required'
+ assert number >= 0, "positive integer required"
if number == 0:
- return '0'
+ return "0"
base36 = []
while number != 0:
number, i = divmod(number, 36)
- base36.append('0123456789abcdefghijklmnopqrstuvwxyz'[i])
- return ''.join(reversed(base36))
+ base36.append("0123456789abcdefghijklmnopqrstuvwxyz"[i])
+ return "".join(reversed(base36))
def is_valid_url(url):
parts = url_parse(url)
- return parts.scheme in ('http', 'https')
+ return parts.scheme in ("http", "https")
def get_hostname(url):
@@ -42,74 +45,77 @@ def get_hostname(url):
class Shortly(object):
-
def __init__(self, config):
- self.redis = redis.Redis(config['redis_host'], config['redis_port'])
- template_path = os.path.join(os.path.dirname(__file__), 'templates')
- self.jinja_env = Environment(loader=FileSystemLoader(template_path),
- autoescape=True)
- self.jinja_env.filters['hostname'] = get_hostname
-
- self.url_map = Map([
- Rule('/', endpoint='new_url'),
- Rule('/<short_id>', endpoint='follow_short_link'),
- Rule('/<short_id>+', endpoint='short_link_details')
- ])
+ self.redis = redis.Redis(config["redis_host"], config["redis_port"])
+ template_path = os.path.join(os.path.dirname(__file__), "templates")
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(template_path), autoescape=True
+ )
+ self.jinja_env.filters["hostname"] = get_hostname
+
+ self.url_map = Map(
+ [
+ Rule("/", endpoint="new_url"),
+ Rule("/<short_id>", endpoint="follow_short_link"),
+ Rule("/<short_id>+", endpoint="short_link_details"),
+ ]
+ )
def on_new_url(self, request):
error = None
- url = ''
- if request.method == 'POST':
- url = request.form['url']
+ url = ""
+ if request.method == "POST":
+ url = request.form["url"]
if not is_valid_url(url):
- error = 'Please enter a valid URL'
+ error = "Please enter a valid URL"
else:
short_id = self.insert_url(url)
- return redirect('/%s+' % short_id)
- return self.render_template('new_url.html', error=error, url=url)
+ return redirect("/%s+" % short_id)
+ return self.render_template("new_url.html", error=error, url=url)
def on_follow_short_link(self, request, short_id):
- link_target = self.redis.get('url-target:' + short_id)
+ link_target = self.redis.get("url-target:" + short_id)
if link_target is None:
raise NotFound()
- self.redis.incr('click-count:' + short_id)
+ self.redis.incr("click-count:" + short_id)
return redirect(link_target)
def on_short_link_details(self, request, short_id):
- link_target = self.redis.get('url-target:' + short_id)
+ link_target = self.redis.get("url-target:" + short_id)
if link_target is None:
raise NotFound()
- click_count = int(self.redis.get('click-count:' + short_id) or 0)
- return self.render_template('short_link_details.html',
+ click_count = int(self.redis.get("click-count:" + short_id) or 0)
+ return self.render_template(
+ "short_link_details.html",
link_target=link_target,
short_id=short_id,
- click_count=click_count
+ click_count=click_count,
)
def error_404(self):
- response = self.render_template('404.html')
+ response = self.render_template("404.html")
response.status_code = 404
return response
def insert_url(self, url):
- short_id = self.redis.get('reverse-url:' + url)
+ short_id = self.redis.get("reverse-url:" + url)
if short_id is not None:
return short_id
- url_num = self.redis.incr('last-url-id')
+ url_num = self.redis.incr("last-url-id")
short_id = base36_encode(url_num)
- self.redis.set('url-target:' + short_id, url)
- self.redis.set('reverse-url:' + url, short_id)
+ self.redis.set("url-target:" + short_id, url)
+ self.redis.set("reverse-url:" + url, short_id)
return short_id
def render_template(self, template_name, **context):
t = self.jinja_env.get_template(template_name)
- return Response(t.render(context), mimetype='text/html')
+ return Response(t.render(context), mimetype="text/html")
def dispatch_request(self, request):
adapter = self.url_map.bind_to_environ(request.environ)
try:
endpoint, values = adapter.match()
- return getattr(self, 'on_' + endpoint)(request, **values)
+ return getattr(self, "on_" + endpoint)(request, **values)
except NotFound:
return self.error_404()
except HTTPException as e:
@@ -124,19 +130,17 @@ class Shortly(object):
return self.wsgi_app(environ, start_response)
-def create_app(redis_host='localhost', redis_port=6379, with_static=True):
- app = Shortly({
- 'redis_host': redis_host,
- 'redis_port': redis_port
- })
+def create_app(redis_host="localhost", redis_port=6379, with_static=True):
+ app = Shortly({"redis_host": redis_host, "redis_port": redis_port})
if with_static:
- app.wsgi_app = SharedDataMiddleware(app.wsgi_app, {
- '/static': os.path.join(os.path.dirname(__file__), 'static')
- })
+ app.wsgi_app = SharedDataMiddleware(
+ app.wsgi_app, {"/static": os.path.join(os.path.dirname(__file__), "static")}
+ )
return app
-if __name__ == '__main__':
+if __name__ == "__main__":
from werkzeug.serving import run_simple
+
app = create_app()
- run_simple('127.0.0.1', 5000, app, use_debugger=True, use_reloader=True)
+ run_simple("127.0.0.1", 5000, app, use_debugger=True, use_reloader=True)
diff --git a/examples/shorty/application.py b/examples/shorty/application.py
index cfec075f..af793444 100644
--- a/examples/shorty/application.py
+++ b/examples/shorty/application.py
@@ -1,24 +1,25 @@
from sqlalchemy import create_engine
-
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import NotFound
from werkzeug.middleware.shared_data import SharedDataMiddleware
from werkzeug.wrappers import Request
from werkzeug.wsgi import ClosingIterator
-from werkzeug.exceptions import HTTPException, NotFound
-from shorty.utils import STATIC_PATH, session, local, local_manager, \
- metadata, url_map
-from shorty import views
+from . import views
+from .utils import local
+from .utils import local_manager
+from .utils import metadata
+from .utils import session
+from .utils import STATIC_PATH
+from .utils import url_map
class Shorty(object):
-
def __init__(self, db_uri):
local.application = self
self.database_engine = create_engine(db_uri, convert_unicode=True)
- self.dispatch = SharedDataMiddleware(self.dispatch, {
- '/static': STATIC_PATH
- })
+ self.dispatch = SharedDataMiddleware(self.dispatch, {"/static": STATIC_PATH})
def init_database(self):
metadata.create_all(self.database_engine)
@@ -36,8 +37,9 @@ class Shorty(object):
response.status_code = 404
except HTTPException as e:
response = e
- return ClosingIterator(response(environ, start_response),
- [session.remove, local_manager.cleanup])
+ return ClosingIterator(
+ response(environ, start_response), [session.remove, local_manager.cleanup]
+ )
def __call__(self, environ, start_response):
return self.dispatch(environ, start_response)
diff --git a/examples/shorty/models.py b/examples/shorty/models.py
index 78b5f1df..7d0df5bd 100644
--- a/examples/shorty/models.py
+++ b/examples/shorty/models.py
@@ -1,15 +1,27 @@
from datetime import datetime
-from sqlalchemy import Table, Column, String, Boolean, DateTime
+
+from sqlalchemy import Boolean
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import String
+from sqlalchemy import Table
from sqlalchemy.orm import mapper
-from shorty.utils import session, metadata, url_for, get_random_uid
-url_table = Table('urls', metadata,
- Column('uid', String(140), primary_key=True),
- Column('target', String(500)),
- Column('added', DateTime),
- Column('public', Boolean)
+from .utils import get_random_uid
+from .utils import metadata
+from .utils import session
+from .utils import url_for
+
+url_table = Table(
+ "urls",
+ metadata,
+ Column("uid", String(140), primary_key=True),
+ Column("target", String(500)),
+ Column("added", DateTime),
+ Column("public", Boolean),
)
+
class URL(object):
query = session.query_property()
@@ -27,9 +39,10 @@ class URL(object):
@property
def short_url(self):
- return url_for('link', uid=self.uid, _external=True)
+ return url_for("link", uid=self.uid, _external=True)
def __repr__(self):
- return '<URL %r>' % self.uid
+ return "<URL %r>" % self.uid
+
mapper(URL, url_table)
diff --git a/examples/shorty/utils.py b/examples/shorty/utils.py
index bbcb0f20..f61f9ef9 100644
--- a/examples/shorty/utils.py
+++ b/examples/shorty/utils.py
@@ -1,57 +1,72 @@
from os import path
-from random import sample, randrange
-from jinja2 import Environment, FileSystemLoader
-from werkzeug.local import Local, LocalManager
+from random import randrange
+from random import sample
+
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+from sqlalchemy import MetaData
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import scoped_session
+from werkzeug.local import Local
+from werkzeug.local import LocalManager
+from werkzeug.routing import Map
+from werkzeug.routing import Rule
from werkzeug.urls import url_parse
from werkzeug.utils import cached_property
from werkzeug.wrappers import Response
-from werkzeug.routing import Map, Rule
-from sqlalchemy import MetaData
-from sqlalchemy.orm import create_session, scoped_session
-TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates')
-STATIC_PATH = path.join(path.dirname(__file__), 'static')
-ALLOWED_SCHEMES = frozenset(['http', 'https', 'ftp', 'ftps'])
-URL_CHARS = 'abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789'
+TEMPLATE_PATH = path.join(path.dirname(__file__), "templates")
+STATIC_PATH = path.join(path.dirname(__file__), "static")
+ALLOWED_SCHEMES = frozenset(["http", "https", "ftp", "ftps"])
+URL_CHARS = "abcdefghijkmpqrstuvwxyzABCDEFGHIJKLMNPQRST23456789"
local = Local()
local_manager = LocalManager([local])
-application = local('application')
+application = local("application")
metadata = MetaData()
-url_map = Map([Rule('/static/<file>', endpoint='static', build_only=True)])
+url_map = Map([Rule("/static/<file>", endpoint="static", build_only=True)])
-session = scoped_session(lambda: create_session(application.database_engine,
- autocommit=False,
- autoflush=False))
+session = scoped_session(
+ lambda: create_session(
+ application.database_engine, autocommit=False, autoflush=False
+ )
+)
jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH))
def expose(rule, **kw):
def decorate(f):
- kw['endpoint'] = f.__name__
+ kw["endpoint"] = f.__name__
url_map.add(Rule(rule, **kw))
return f
+
return decorate
+
def url_for(endpoint, _external=False, **values):
return local.url_adapter.build(endpoint, values, force_external=_external)
-jinja_env.globals['url_for'] = url_for
+
+
+jinja_env.globals["url_for"] = url_for
+
def render_template(template, **context):
- return Response(jinja_env.get_template(template).render(**context),
- mimetype='text/html')
+ return Response(
+ jinja_env.get_template(template).render(**context), mimetype="text/html"
+ )
+
def validate_url(url):
return url_parse(url)[0] in ALLOWED_SCHEMES
+
def get_random_uid():
- return ''.join(sample(URL_CHARS, randrange(3, 9)))
+ return "".join(sample(URL_CHARS, randrange(3, 9)))
class Pagination(object):
-
def __init__(self, query, per_page, page, endpoint):
self.query = query
self.per_page = per_page
@@ -64,8 +79,11 @@ class Pagination(object):
@cached_property
def entries(self):
- return self.query.offset((self.page - 1) * self.per_page) \
- .limit(self.per_page).all()
+ return (
+ self.query.offset((self.page - 1) * self.per_page)
+ .limit(self.per_page)
+ .all()
+ )
has_previous = property(lambda self: self.page > 1)
has_next = property(lambda self: self.page < self.pages)
diff --git a/examples/shorty/views.py b/examples/shorty/views.py
index fa4269a3..7a1ee20b 100644
--- a/examples/shorty/views.py
+++ b/examples/shorty/views.py
@@ -1,52 +1,62 @@
-from werkzeug.utils import redirect
from werkzeug.exceptions import NotFound
-from shorty.utils import session, Pagination, render_template, expose, \
- validate_url, url_for
-from shorty.models import URL
+from werkzeug.utils import redirect
+
+from .models import URL
+from .utils import expose
+from .utils import Pagination
+from .utils import render_template
+from .utils import session
+from .utils import url_for
+from .utils import validate_url
+
-@expose('/')
+@expose("/")
def new(request):
- error = url = ''
- if request.method == 'POST':
- url = request.form.get('url')
- alias = request.form.get('alias')
+ error = url = ""
+ if request.method == "POST":
+ url = request.form.get("url")
+ alias = request.form.get("alias")
if not validate_url(url):
error = "I'm sorry but you cannot shorten this URL."
elif alias:
if len(alias) > 140:
- error = 'Your alias is too long'
- elif '/' in alias:
- error = 'Your alias might not include a slash'
+ error = "Your alias is too long"
+ elif "/" in alias:
+ error = "Your alias might not include a slash"
elif URL.query.get(alias):
- error = 'The alias you have requested exists already'
+ error = "The alias you have requested exists already"
if not error:
- uid = URL(url, 'private' not in request.form, alias).uid
+ uid = URL(url, "private" not in request.form, alias).uid
session.commit()
- return redirect(url_for('display', uid=uid))
- return render_template('new.html', error=error, url=url)
+ return redirect(url_for("display", uid=uid))
+ return render_template("new.html", error=error, url=url)
-@expose('/display/<uid>')
+
+@expose("/display/<uid>")
def display(request, uid):
url = URL.query.get(uid)
if not url:
raise NotFound()
- return render_template('display.html', url=url)
+ return render_template("display.html", url=url)
+
-@expose('/u/<uid>')
+@expose("/u/<uid>")
def link(request, uid):
url = URL.query.get(uid)
if not url:
raise NotFound()
return redirect(url.target, 301)
-@expose('/list/', defaults={'page': 1})
-@expose('/list/<int:page>')
+
+@expose("/list/", defaults={"page": 1})
+@expose("/list/<int:page>")
def list(request, page):
query = URL.query.filter_by(public=True)
- pagination = Pagination(query, 30, page, 'list')
+ pagination = Pagination(query, 30, page, "list")
if pagination.page > 1 and not pagination.entries:
raise NotFound()
- return render_template('list.html', pagination=pagination)
+ return render_template("list.html", pagination=pagination)
+
def not_found(request):
- return render_template('not_found.html')
+ return render_template("not_found.html")
diff --git a/examples/simplewiki/__init__.py b/examples/simplewiki/__init__.py
index c74c2cf6..591a06e9 100644
--- a/examples/simplewiki/__init__.py
+++ b/examples/simplewiki/__init__.py
@@ -9,4 +9,4 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from simplewiki.application import SimpleWiki
+from .application import SimpleWiki
diff --git a/examples/simplewiki/actions.py b/examples/simplewiki/actions.py
index 29b2ff5b..c9a88b41 100644
--- a/examples/simplewiki/actions.py
+++ b/examples/simplewiki/actions.py
@@ -13,15 +13,22 @@
:license: BSD-3-Clause
"""
from difflib import unified_diff
-from simplewiki.utils import Response, generate_template, \
- href, format_datetime
-from simplewiki.database import RevisionedPage, Page, Revision, session
+
from werkzeug.utils import redirect
+from .database import Page
+from .database import Revision
+from .database import RevisionedPage
+from .database import session
+from .utils import format_datetime
+from .utils import generate_template
+from .utils import href
+from .utils import Response
+
def on_show(request, page_name):
"""Displays the page the user requests."""
- revision_id = request.args.get('rev', type=int)
+ revision_id = request.args.get("rev", type=int)
query = RevisionedPage.query.filter_by(name=page_name)
if revision_id:
query = query.filter_by(revision_id=revision_id)
@@ -32,32 +39,32 @@ def on_show(request, page_name):
page = query.first()
if page is None:
return page_missing(request, page_name, revision_requested)
- return Response(generate_template('action_show.html',
- page=page
- ))
+ return Response(generate_template("action_show.html", page=page))
def on_edit(request, page_name):
"""Edit the current revision of a page."""
- change_note = error = ''
- revision = Revision.query.filter(
- (Page.name == page_name) &
- (Page.page_id == Revision.page_id)
- ).order_by(Revision.revision_id.desc()).first()
+ change_note = error = ""
+ revision = (
+ Revision.query.filter(
+ (Page.name == page_name) & (Page.page_id == Revision.page_id)
+ )
+ .order_by(Revision.revision_id.desc())
+ .first()
+ )
if revision is None:
page = None
else:
page = revision.page
- if request.method == 'POST':
- text = request.form.get('text')
- if request.form.get('cancel') or \
- revision and revision.text == text:
+ if request.method == "POST":
+ text = request.form.get("text")
+ if request.form.get("cancel") or revision and revision.text == text:
return redirect(href(page.name))
elif not text:
- error = 'You cannot save empty revisions.'
+ error = "You cannot save empty revisions."
else:
- change_note = request.form.get('change_note', '')
+ change_note = request.form.get("change_note", "")
if page is None:
page = Page(page_name)
session.add(page)
@@ -65,14 +72,17 @@ def on_edit(request, page_name):
session.commit()
return redirect(href(page.name))
- return Response(generate_template('action_edit.html',
- revision=revision,
- page=page,
- new=page is None,
- page_name=page_name,
- change_note=change_note,
- error=error
- ))
+ return Response(
+ generate_template(
+ "action_edit.html",
+ revision=revision,
+ page=page,
+ new=page is None,
+ page_name=page_name,
+ change_note=change_note,
+ error=error,
+ )
+ )
def on_log(request, page_name):
@@ -80,109 +90,117 @@ def on_log(request, page_name):
page = Page.query.filter_by(name=page_name).first()
if page is None:
return page_missing(request, page_name, False)
- return Response(generate_template('action_log.html',
- page=page
- ))
+ return Response(generate_template("action_log.html", page=page))
def on_diff(request, page_name):
"""Show the diff between two revisions."""
- old = request.args.get('old', type=int)
- new = request.args.get('new', type=int)
- error = ''
+ old = request.args.get("old", type=int)
+ new = request.args.get("new", type=int)
+ error = ""
diff = page = old_rev = new_rev = None
if not (old and new):
- error = 'No revisions specified.'
+ error = "No revisions specified."
else:
- revisions = dict((x.revision_id, x) for x in Revision.query.filter(
- (Revision.revision_id.in_((old, new))) &
- (Revision.page_id == Page.page_id) &
- (Page.name == page_name)
- ))
+ revisions = dict(
+ (x.revision_id, x)
+ for x in Revision.query.filter(
+ (Revision.revision_id.in_((old, new)))
+ & (Revision.page_id == Page.page_id)
+ & (Page.name == page_name)
+ )
+ )
if len(revisions) != 2:
- error = 'At least one of the revisions requested ' \
- 'does not exist.'
+ error = "At least one of the revisions requested does not exist."
else:
new_rev = revisions[new]
old_rev = revisions[old]
page = old_rev.page
diff = unified_diff(
- (old_rev.text + '\n').splitlines(True),
- (new_rev.text + '\n').splitlines(True),
- page.name, page.name,
+ (old_rev.text + "\n").splitlines(True),
+ (new_rev.text + "\n").splitlines(True),
+ page.name,
+ page.name,
format_datetime(old_rev.timestamp),
format_datetime(new_rev.timestamp),
- 3
+ 3,
)
- return Response(generate_template('action_diff.html',
- error=error,
- old_revision=old_rev,
- new_revision=new_rev,
- page=page,
- diff=diff
- ))
+ return Response(
+ generate_template(
+ "action_diff.html",
+ error=error,
+ old_revision=old_rev,
+ new_revision=new_rev,
+ page=page,
+ diff=diff,
+ )
+ )
def on_revert(request, page_name):
"""Revert an old revision."""
- rev_id = request.args.get('rev', type=int)
+ rev_id = request.args.get("rev", type=int)
old_revision = page = None
- error = 'No such revision'
+ error = "No such revision"
- if request.method == 'POST' and request.form.get('cancel'):
+ if request.method == "POST" and request.form.get("cancel"):
return redirect(href(page_name))
if rev_id:
old_revision = Revision.query.filter(
- (Revision.revision_id == rev_id) &
- (Revision.page_id == Page.page_id) &
- (Page.name == page_name)
+ (Revision.revision_id == rev_id)
+ & (Revision.page_id == Page.page_id)
+ & (Page.name == page_name)
).first()
if old_revision:
- new_revision = Revision.query.filter(
- (Revision.page_id == Page.page_id) &
- (Page.name == page_name)
- ).order_by(Revision.revision_id.desc()).first()
+ new_revision = (
+ Revision.query.filter(
+ (Revision.page_id == Page.page_id) & (Page.name == page_name)
+ )
+ .order_by(Revision.revision_id.desc())
+ .first()
+ )
if old_revision == new_revision:
- error = 'You tried to revert the current active ' \
- 'revision.'
+ error = "You tried to revert the current active revision."
elif old_revision.text == new_revision.text:
- error = 'There are no changes between the current ' \
- 'revision and the revision you want to ' \
- 'restore.'
+ error = (
+ "There are no changes between the current "
+ "revision and the revision you want to "
+ "restore."
+ )
else:
- error = ''
+ error = ""
page = old_revision.page
- if request.method == 'POST':
- change_note = request.form.get('change_note', '')
- change_note = 'revert' + (change_note and ': ' +
- change_note or '')
- session.add(Revision(page, old_revision.text,
- change_note))
+ if request.method == "POST":
+ change_note = request.form.get("change_note", "")
+ change_note = "revert" + (change_note and ": " + change_note or "")
+ session.add(Revision(page, old_revision.text, change_note))
session.commit()
return redirect(href(page_name))
- return Response(generate_template('action_revert.html',
- error=error,
- old_revision=old_revision,
- page=page
- ))
+ return Response(
+ generate_template(
+ "action_revert.html", error=error, old_revision=old_revision, page=page
+ )
+ )
def page_missing(request, page_name, revision_requested, protected=False):
"""Displayed if page or revision does not exist."""
- return Response(generate_template('page_missing.html',
- page_name=page_name,
- revision_requested=revision_requested,
- protected=protected
- ), status=404)
+ return Response(
+ generate_template(
+ "page_missing.html",
+ page_name=page_name,
+ revision_requested=revision_requested,
+ protected=protected,
+ ),
+ status=404,
+ )
def missing_action(request, action):
"""Displayed if a user tried to access a action that does not exist."""
- return Response(generate_template('missing_action.html',
- action=action
- ), status=404)
+ return Response(generate_template("missing_action.html", action=action), status=404)
diff --git a/examples/simplewiki/application.py b/examples/simplewiki/application.py
index b994b17d..5f85eef3 100644
--- a/examples/simplewiki/application.py
+++ b/examples/simplewiki/application.py
@@ -11,19 +11,25 @@
:license: BSD-3-Clause
"""
from os import path
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine
from werkzeug.middleware.shared_data import SharedDataMiddleware
from werkzeug.utils import redirect
from werkzeug.wsgi import ClosingIterator
-from simplewiki.utils import Request, local, local_manager, href
-from simplewiki.database import session, metadata
-from simplewiki import actions
-from simplewiki.specialpages import pages, page_not_found
+
+from . import actions
+from .database import metadata
+from .database import session
+from .specialpages import page_not_found
+from .specialpages import pages
+from .utils import href
+from .utils import local
+from .utils import local_manager
+from .utils import Request
#: path to shared data
-SHARED_DATA = path.join(path.dirname(__file__), 'shared')
+SHARED_DATA = path.join(path.dirname(__file__), "shared")
class SimpleWiki(object):
@@ -37,9 +43,9 @@ class SimpleWiki(object):
# apply our middlewares. we apply the middlewars *inside* the
# application and not outside of it so that we never lose the
# reference to the `SimpleWiki` object.
- self._dispatch = SharedDataMiddleware(self.dispatch_request, {
- '/_shared': SHARED_DATA
- })
+ self._dispatch = SharedDataMiddleware(
+ self.dispatch_request, {"/_shared": SHARED_DATA}
+ )
# free the context locals at the end of the request
self._dispatch = local_manager.make_middleware(self._dispatch)
@@ -66,16 +72,15 @@ class SimpleWiki(object):
# get the current action from the url and normalize the page name
# which is just the request path
- action_name = request.args.get('action') or 'show'
- page_name = u'_'.join([x for x in request.path.strip('/')
- .split() if x])
+ action_name = request.args.get("action") or "show"
+ page_name = u"_".join([x for x in request.path.strip("/").split() if x])
# redirect to the Main_Page if the user requested the index
if not page_name:
- response = redirect(href('Main_Page'))
+ response = redirect(href("Main_Page"))
# check special pages
- elif page_name.startswith('Special:'):
+ elif page_name.startswith("Special:"):
if page_name[8:] not in pages:
response = page_not_found(request, page_name)
else:
@@ -85,15 +90,14 @@ class SimpleWiki(object):
# action module. It's "on_" + the action name. If it doesn't
# exists call the missing_action method from the same module.
else:
- action = getattr(actions, 'on_' + action_name, None)
+ action = getattr(actions, "on_" + action_name, None)
if action is None:
response = actions.missing_action(request, action_name)
else:
response = action(request, page_name)
# make sure the session is removed properly
- return ClosingIterator(response(environ, start_response),
- session.remove)
+ return ClosingIterator(response(environ, start_response), session.remove)
def __call__(self, environ, start_response):
"""Just forward a WSGI call to the first internal middleware."""
diff --git a/examples/simplewiki/database.py b/examples/simplewiki/database.py
index d5df7387..f0cec34e 100644
--- a/examples/simplewiki/database.py
+++ b/examples/simplewiki/database.py
@@ -9,11 +9,23 @@
:license: BSD-3-Clause
"""
from datetime import datetime
-from sqlalchemy import Table, Column, Integer, String, DateTime, \
- ForeignKey, MetaData, join
-from sqlalchemy.orm import relation, create_session, scoped_session, \
- mapper
-from simplewiki.utils import application, local_manager, parse_creole
+
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import join
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relation
+from sqlalchemy.orm import scoped_session
+
+from .utils import application
+from .utils import local_manager
+from .utils import parse_creole
# create a global metadata
@@ -28,8 +40,7 @@ def new_db_session():
application. If there is no application bound to the context it
raises an exception.
"""
- return create_session(application.database_engine, autoflush=True,
- autocommit=False)
+ return create_session(application.database_engine, autoflush=True, autocommit=False)
# and create a new global session factory. Calling this object gives
@@ -38,17 +49,21 @@ session = scoped_session(new_db_session, local_manager.get_ident)
# our database tables.
-page_table = Table('pages', metadata,
- Column('page_id', Integer, primary_key=True),
- Column('name', String(60), unique=True)
+page_table = Table(
+ "pages",
+ metadata,
+ Column("page_id", Integer, primary_key=True),
+ Column("name", String(60), unique=True),
)
-revision_table = Table('revisions', metadata,
- Column('revision_id', Integer, primary_key=True),
- Column('page_id', Integer, ForeignKey('pages.page_id')),
- Column('timestamp', DateTime),
- Column('text', String),
- Column('change_note', String(200))
+revision_table = Table(
+ "revisions",
+ metadata,
+ Column("revision_id", Integer, primary_key=True),
+ Column("page_id", Integer, ForeignKey("pages.page_id")),
+ Column("timestamp", DateTime),
+ Column("text", String),
+ Column("change_note", String(200)),
)
@@ -59,9 +74,10 @@ class Revision(object):
new revisions. It's also used for the diff system and the revision
log.
"""
+
query = session.query_property()
- def __init__(self, page, text, change_note='', timestamp=None):
+ def __init__(self, page, text, change_note="", timestamp=None):
if isinstance(page, int):
self.page_id = page
else:
@@ -75,11 +91,7 @@ class Revision(object):
return parse_creole(self.text)
def __repr__(self):
- return '<%s %r:%r>' % (
- self.__class__.__name__,
- self.page_id,
- self.revision_id
- )
+ return "<%s %r:%r>" % (self.__class__.__name__, self.page_id, self.revision_id)
class Page(object):
@@ -87,6 +99,7 @@ class Page(object):
Represents a simple page without any revisions. This is for example
used in the page index where the page contents are not relevant.
"""
+
query = session.query_property()
def __init__(self, name):
@@ -94,10 +107,10 @@ class Page(object):
@property
def title(self):
- return self.name.replace('_', ' ')
+ return self.name.replace("_", " ")
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, self.name)
+ return "<%s %r>" % (self.__class__.__name__, self.name)
class RevisionedPage(Page, Revision):
@@ -106,26 +119,32 @@ class RevisionedPage(Page, Revision):
and the ability of SQLAlchemy to map to joins we can combine `Page` and
`Revision` into one class here.
"""
+
query = session.query_property()
def __init__(self):
- raise TypeError('cannot create WikiPage instances, use the Page and '
- 'Revision classes for data manipulation.')
+ raise TypeError(
+ "cannot create WikiPage instances, use the Page and "
+ "Revision classes for data manipulation."
+ )
def __repr__(self):
- return '<%s %r:%r>' % (
- self.__class__.__name__,
- self.name,
- self.revision_id
- )
+ return "<%s %r:%r>" % (self.__class__.__name__, self.name, self.revision_id)
# setup mappers
mapper(Revision, revision_table)
-mapper(Page, page_table, properties=dict(
- revisions=relation(Revision, backref='page',
- order_by=Revision.revision_id.desc())
-))
-mapper(RevisionedPage, join(page_table, revision_table), properties=dict(
- page_id=[page_table.c.page_id, revision_table.c.page_id],
-))
+mapper(
+ Page,
+ page_table,
+ properties=dict(
+ revisions=relation(
+ Revision, backref="page", order_by=Revision.revision_id.desc()
+ )
+ ),
+)
+mapper(
+ RevisionedPage,
+ join(page_table, revision_table),
+ properties=dict(page_id=[page_table.c.page_id, revision_table.c.page_id]),
+)
diff --git a/examples/simplewiki/specialpages.py b/examples/simplewiki/specialpages.py
index 9717d6c1..9636f7ca 100644
--- a/examples/simplewiki/specialpages.py
+++ b/examples/simplewiki/specialpages.py
@@ -9,10 +9,12 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from simplewiki.utils import Response, Pagination, generate_template
-from simplewiki.database import RevisionedPage, Page
-from simplewiki.actions import page_missing
-
+from .actions import page_missing
+from .database import Page
+from .database import RevisionedPage
+from .utils import generate_template
+from .utils import Pagination
+from .utils import Response
def page_index(request):
@@ -20,19 +22,21 @@ def page_index(request):
letters = {}
for page in Page.query.order_by(Page.name):
letters.setdefault(page.name.capitalize()[0], []).append(page)
- return Response(generate_template('page_index.html',
- letters=sorted(letters.items())
- ))
+ return Response(
+ generate_template("page_index.html", letters=sorted(letters.items()))
+ )
def recent_changes(request):
"""Display the recent changes."""
- page = max(1, request.args.get('page', type=int))
- query = RevisionedPage.query \
- .order_by(RevisionedPage.revision_id.desc())
- return Response(generate_template('recent_changes.html',
- pagination=Pagination(query, 20, page, 'Special:Recent_Changes')
- ))
+ page = max(1, request.args.get("page", type=int))
+ query = RevisionedPage.query.order_by(RevisionedPage.revision_id.desc())
+ return Response(
+ generate_template(
+ "recent_changes.html",
+ pagination=Pagination(query, 20, page, "Special:Recent_Changes"),
+ )
+ )
def page_not_found(request, page_name):
@@ -43,7 +47,4 @@ def page_not_found(request, page_name):
return page_missing(request, page_name, True)
-pages = {
- 'Index': page_index,
- 'Recent_Changes': recent_changes
-}
+pages = {"Index": page_index, "Recent_Changes": recent_changes}
diff --git a/examples/simplewiki/templates/macros.xml b/examples/simplewiki/templates/macros.xml
index 14bebd83..28ce06b9 100644
--- a/examples/simplewiki/templates/macros.xml
+++ b/examples/simplewiki/templates/macros.xml
@@ -1,6 +1,6 @@
<div xmlns="http://www.w3.org/1999/xhtml" xmlns:py="http://genshi.edgewall.org/"
py:strip="">
-
+
<py:def function="render_pagination(pagination)">
<div class="pagination" py:if="pagination.pages > 1">
<py:choose test="pagination.has_previous">
diff --git a/examples/simplewiki/utils.py b/examples/simplewiki/utils.py
index f6f11d29..66289e8b 100644
--- a/examples/simplewiki/utils.py
+++ b/examples/simplewiki/utils.py
@@ -9,20 +9,25 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import creoleparser
from os import path
+
+import creoleparser
from genshi import Stream
from genshi.template import TemplateLoader
-from werkzeug.local import Local, LocalManager
-from werkzeug.urls import url_encode, url_quote
+from werkzeug.local import Local
+from werkzeug.local import LocalManager
+from werkzeug.urls import url_encode
+from werkzeug.urls import url_quote
from werkzeug.utils import cached_property
-from werkzeug.wrappers import BaseRequest, BaseResponse
+from werkzeug.wrappers import BaseRequest
+from werkzeug.wrappers import BaseResponse
# calculate the path to the templates an create the template loader
-TEMPLATE_PATH = path.join(path.dirname(__file__), 'templates')
-template_loader = TemplateLoader(TEMPLATE_PATH, auto_reload=True,
- variable_lookup='lenient')
+TEMPLATE_PATH = path.join(path.dirname(__file__), "templates")
+template_loader = TemplateLoader(
+ TEMPLATE_PATH, auto_reload=True, variable_lookup="lenient"
+)
# context locals. these two objects are use by the application to
@@ -30,27 +35,25 @@ template_loader = TemplateLoader(TEMPLATE_PATH, auto_reload=True,
# current thread and the current greenlet if there is greenlet support.
local = Local()
local_manager = LocalManager([local])
-request = local('request')
-application = local('application')
+request = local("request")
+application = local("application")
# create a new creole parser
creole_parser = creoleparser.Parser(
- dialect=creoleparser.create_dialect(creoleparser.creole10_base,
- wiki_links_base_url='',
+ dialect=creoleparser.create_dialect(
+ creoleparser.creole10_base,
+ wiki_links_base_url="",
wiki_links_path_func=lambda page_name: href(page_name),
- wiki_links_space_char='_',
- no_wiki_monospace=True
+ wiki_links_space_char="_",
+ no_wiki_monospace=True,
),
- method='html'
+ method="html",
)
def generate_template(template_name, **context):
"""Load and generate a template."""
- context.update(
- href=href,
- format_datetime=format_datetime
- )
+ context.update(href=href, format_datetime=format_datetime)
return template_loader.load(template_name).generate(**context)
@@ -64,17 +67,17 @@ def href(*args, **kw):
Simple function for URL generation. Position arguments are used for the
URL path and keyword arguments are used for the url parameters.
"""
- result = [(request.script_root if request else '') + '/']
+ result = [(request.script_root if request else "") + "/"]
for idx, arg in enumerate(args):
- result.append(('/' if idx else '') + url_quote(arg))
+ result.append(("/" if idx else "") + url_quote(arg))
if kw:
- result.append('?' + url_encode(kw))
- return ''.join(result)
+ result.append("?" + url_encode(kw))
+ return "".join(result)
def format_datetime(obj):
"""Format a datetime object."""
- return obj.strftime('%Y-%m-%d %H:%M')
+ return obj.strftime("%Y-%m-%d %H:%M")
class Request(BaseRequest):
@@ -94,14 +97,14 @@ class Response(BaseResponse):
to html. This makes it possible to switch to xhtml or html5 easily.
"""
- default_mimetype = 'text/html'
+ default_mimetype = "text/html"
- def __init__(self, response=None, status=200, headers=None, mimetype=None,
- content_type=None):
+ def __init__(
+ self, response=None, status=200, headers=None, mimetype=None, content_type=None
+ ):
if isinstance(response, Stream):
- response = response.render('html', encoding=None, doctype='html')
- BaseResponse.__init__(self, response, status, headers, mimetype,
- content_type)
+ response = response.render("html", encoding=None, doctype="html")
+ BaseResponse.__init__(self, response, status, headers, mimetype, content_type)
class Pagination(object):
@@ -118,8 +121,11 @@ class Pagination(object):
@cached_property
def entries(self):
- return self.query.offset((self.page - 1) * self.per_page) \
- .limit(self.per_page).all()
+ return (
+ self.query.offset((self.page - 1) * self.per_page)
+ .limit(self.per_page)
+ .all()
+ )
@property
def has_previous(self):
diff --git a/examples/upload.py b/examples/upload.py
index 88af63aa..6f520b42 100644
--- a/examples/upload.py
+++ b/examples/upload.py
@@ -9,36 +9,39 @@
:license: BSD-3-Clause
"""
from werkzeug.serving import run_simple
-from werkzeug.wrappers import BaseRequest, BaseResponse
+from werkzeug.wrappers import BaseRequest
+from werkzeug.wrappers import BaseResponse
from werkzeug.wsgi import wrap_file
def view_file(req):
- if not 'uploaded_file' in req.files:
- return BaseResponse('no file uploaded')
- f = req.files['uploaded_file']
- return BaseResponse(wrap_file(req.environ, f), mimetype=f.content_type,
- direct_passthrough=True)
+ if "uploaded_file" not in req.files:
+ return BaseResponse("no file uploaded")
+ f = req.files["uploaded_file"]
+ return BaseResponse(
+ wrap_file(req.environ, f), mimetype=f.content_type, direct_passthrough=True
+ )
def upload_file(req):
- return BaseResponse('''
- <h1>Upload File</h1>
- <form action="" method="post" enctype="multipart/form-data">
- <input type="file" name="uploaded_file">
- <input type="submit" value="Upload">
- </form>
- ''', mimetype='text/html')
+ return BaseResponse(
+ """<h1>Upload File</h1>
+ <form action="" method="post" enctype="multipart/form-data">
+ <input type="file" name="uploaded_file">
+ <input type="submit" value="Upload">
+ </form>""",
+ mimetype="text/html",
+ )
def application(environ, start_response):
req = BaseRequest(environ)
- if req.method == 'POST':
+ if req.method == "POST":
resp = view_file(req)
else:
resp = upload_file(req)
return resp(environ, start_response)
-if __name__ == '__main__':
- run_simple('localhost', 5000, application, use_debugger=True)
+if __name__ == "__main__":
+ run_simple("localhost", 5000, application, use_debugger=True)
diff --git a/examples/webpylike/example.py b/examples/webpylike/example.py
index 7dc52a87..0bb6896d 100644
--- a/examples/webpylike/example.py
+++ b/examples/webpylike/example.py
@@ -8,23 +8,22 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from webpylike.webpylike import WebPyApp, View, Response
+from .webpylike import Response
+from .webpylike import View
+from .webpylike import WebPyApp
-urls = (
- '/', 'index',
- '/about', 'about'
-)
+urls = ("/", "index", "/about", "about")
class index(View):
def GET(self):
- return Response('Hello World')
+ return Response("Hello World")
class about(View):
def GET(self):
- return Response('This is the about page')
+ return Response("This is the about page")
app = WebPyApp(urls, globals())
diff --git a/examples/webpylike/webpylike.py b/examples/webpylike/webpylike.py
index 6669ec63..0a7325ca 100644
--- a/examples/webpylike/webpylike.py
+++ b/examples/webpylike/webpylike.py
@@ -11,9 +11,13 @@
:license: BSD-3-Clause
"""
import re
-from werkzeug.wrappers import BaseRequest, BaseResponse
-from werkzeug.exceptions import HTTPException, MethodNotAllowed, \
- NotImplemented, NotFound
+
+from werkzeug.exceptions import HTTPException
+from werkzeug.exceptions import MethodNotAllowed
+from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import NotImplemented
+from werkzeug.wrappers import BaseRequest
+from werkzeug.wrappers import BaseResponse
class Request(BaseRequest):
@@ -33,6 +37,7 @@ class View(object):
def GET(self):
raise MethodNotAllowed()
+
POST = DELETE = PUT = GET
def HEAD(self):
@@ -46,8 +51,9 @@ class WebPyApp(object):
"""
def __init__(self, urls, views):
- self.urls = [(re.compile('^%s$' % urls[i]), urls[i + 1])
- for i in range(0, len(urls), 2)]
+ self.urls = [
+ (re.compile("^%s$" % urls[i]), urls[i + 1]) for i in range(0, len(urls), 2)
+ ]
self.views = views
def __call__(self, environ, start_response):
@@ -57,9 +63,8 @@ class WebPyApp(object):
match = regex.match(req.path)
if match is not None:
view = self.views[view](self, req)
- if req.method not in ('GET', 'HEAD', 'POST',
- 'DELETE', 'PUT'):
- raise NotImplemented()
+ if req.method not in ("GET", "HEAD", "POST", "DELETE", "PUT"):
+ raise NotImplemented() # noqa: F901
resp = getattr(view, req.method)(*match.groups())
break
else:
diff --git a/setup.cfg b/setup.cfg
index aa4240d8..1def1f7b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,8 +28,27 @@ source =
.tox/*/site-packages/werkzeug
[flake8]
-ignore = E126,E241,E272,E305,E402,E731,W503
-exclude=.tox,examples,docs
-max-line-length=100
+# B = bugbear
+# E = pycodestyle errors
+# F = flake8 pyflakes
+# W = pycodestyle warnings
+# B9 = bugbear opinions
+select = B, E, F, W, B9
+ignore =
+ # slice notation whitespace, invalid
+ E203
+ # import at top, too many circular import fixes
+ E402
+ # line length, handled by bugbear B950
+ E501
+ # bare except, handled by bugbear B001
+ E722
+ # bin op line break, invalid
+ W503
+# up to 88 allowed by bugbear B950
+max-line-length = 80
per-file-ignores =
- src/werkzeug/wrappers/__init__.py: F401
+ # __init__ modules export names
+ **/__init__.py: F401
+ # LocalProxy assigns lambdas
+ src/werkzeug/local.py: E731
diff --git a/setup.py b/setup.py
index 68a17131..39836d74 100644
--- a/setup.py
+++ b/setup.py
@@ -1,17 +1,17 @@
import io
import re
-from setuptools import find_packages, setup
+from setuptools import find_packages
+from setuptools import setup
-with io.open('README.rst', 'rt', encoding='utf8') as f:
+with io.open("README.rst", "rt", encoding="utf8") as f:
readme = f.read()
-with io.open('src/werkzeug/__init__.py', 'rt', encoding='utf8') as f:
- version = re.search(
- r'__version__ = \'(.*?)\'', f.read(), re.M).group(1)
+with io.open("src/werkzeug/__init__.py", "rt", encoding="utf8") as f:
+ version = re.search(r'__version__ = "(.*?)"', f.read(), re.M).group(1)
setup(
- name='Werkzeug',
+ name="Werkzeug",
version=version,
url="https://palletsprojects.com/p/werkzeug/",
project_urls={
@@ -19,53 +19,44 @@ setup(
"Code": "https://github.com/pallets/werkzeug",
"Issue tracker": "https://github.com/pallets/werkzeug/issues",
},
- license='BSD-3-Clause',
- author='Armin Ronacher',
- author_email='armin.ronacher@active-4.com',
+ license="BSD-3-Clause",
+ author="Armin Ronacher",
+ author_email="armin.ronacher@active-4.com",
maintainer="The Pallets Team",
maintainer_email="contact@palletsprojects.com",
- description='The comprehensive WSGI web application library.',
+ description="The comprehensive WSGI web application library.",
long_description=readme,
classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Environment :: Web Environment',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: BSD License',
- 'Operating System :: OS Independent',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
+ "Development Status :: 5 - Production/Stable",
+ "Environment :: Web Environment",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
- 'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
- 'Topic :: Internet :: WWW/HTTP :: WSGI',
- 'Topic :: Internet :: WWW/HTTP :: WSGI :: Application',
- 'Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware',
- 'Topic :: Software Development :: Libraries :: Application Frameworks',
- 'Topic :: Software Development :: Libraries :: Python Modules',
+ "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
+ "Topic :: Internet :: WWW/HTTP :: WSGI",
+ "Topic :: Internet :: WWW/HTTP :: WSGI :: Application",
+ "Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware",
+ "Topic :: Software Development :: Libraries :: Application Frameworks",
+ "Topic :: Software Development :: Libraries :: Python Modules",
],
packages=find_packages("src"),
package_dir={"": "src"},
include_package_data=True,
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
extras_require={
- 'watchdog': ['watchdog'],
- 'termcolor': ['termcolor'],
- 'dev': [
- 'pytest',
- 'coverage',
- 'tox',
- 'sphinx',
- 'pallets-sphinx-themes',
- ],
- 'docs': [
- 'sphinx',
- 'pallets-sphinx-themes',
- ]
+ "watchdog": ["watchdog"],
+ "termcolor": ["termcolor"],
+ "dev": ["pytest", "coverage", "tox", "sphinx", "pallets-sphinx-themes"],
+ "docs": ["sphinx", "pallets-sphinx-themes"],
},
)
diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py
index f8d0d447..5a0c4e9e 100644
--- a/src/werkzeug/__init__.py
+++ b/src/werkzeug/__init__.py
@@ -14,12 +14,10 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from types import ModuleType
import sys
+from types import ModuleType
-from werkzeug._compat import iteritems
-
-__version__ = '0.15.dev'
+__version__ = "0.15.dev"
# This import magic raises concerns quite often which is why the implementation
# and motivation is explained here in detail now.
@@ -34,82 +32,153 @@ __version__ = '0.15.dev'
# werkzeug package when imported from within. Attribute access to the werkzeug
# module will then lazily import from the modules that implement the objects.
-
# import mapping to objects in other modules
all_by_module = {
- 'werkzeug.debug': ['DebuggedApplication'],
- 'werkzeug.local': ['Local', 'LocalManager', 'LocalProxy', 'LocalStack',
- 'release_local'],
- 'werkzeug.serving': ['run_simple'],
- 'werkzeug.test': ['Client', 'EnvironBuilder', 'create_environ',
- 'run_wsgi_app'],
- 'werkzeug.testapp': ['test_app'],
- 'werkzeug.exceptions': ['abort', 'Aborter'],
- 'werkzeug.urls': ['url_decode', 'url_encode', 'url_quote',
- 'url_quote_plus', 'url_unquote', 'url_unquote_plus',
- 'url_fix', 'Href', 'iri_to_uri', 'uri_to_iri'],
- 'werkzeug.formparser': ['parse_form_data'],
- 'werkzeug.utils': ['escape', 'environ_property', 'append_slash_redirect',
- 'redirect', 'cached_property', 'import_string',
- 'dump_cookie', 'parse_cookie', 'unescape',
- 'format_string', 'find_modules', 'header_property',
- 'html', 'xhtml', 'HTMLBuilder', 'validate_arguments',
- 'ArgumentValidationError', 'bind_arguments',
- 'secure_filename'],
- 'werkzeug.wsgi': ['get_current_url', 'get_host', 'pop_path_info',
- 'peek_path_info',
- 'ClosingIterator', 'FileWrapper',
- 'make_line_iter', 'LimitedStream', 'responder',
- 'wrap_file', 'extract_path_info'],
- 'werkzeug.datastructures': ['MultiDict', 'CombinedMultiDict', 'Headers',
- 'EnvironHeaders', 'ImmutableList',
- 'ImmutableDict', 'ImmutableMultiDict',
- 'TypeConversionDict',
- 'ImmutableTypeConversionDict', 'Accept',
- 'MIMEAccept', 'CharsetAccept',
- 'LanguageAccept', 'RequestCacheControl',
- 'ResponseCacheControl', 'ETags', 'HeaderSet',
- 'WWWAuthenticate', 'Authorization',
- 'FileMultiDict', 'CallbackDict', 'FileStorage',
- 'OrderedMultiDict', 'ImmutableOrderedMultiDict'
- ],
- 'werkzeug.useragents': ['UserAgent'],
- 'werkzeug.http': ['parse_etags', 'parse_date', 'http_date', 'cookie_date',
- 'parse_cache_control_header', 'is_resource_modified',
- 'parse_accept_header', 'parse_set_header', 'quote_etag',
- 'unquote_etag', 'generate_etag', 'dump_header',
- 'parse_list_header', 'parse_dict_header',
- 'parse_authorization_header',
- 'parse_www_authenticate_header', 'remove_entity_headers',
- 'is_entity_header', 'remove_hop_by_hop_headers',
- 'parse_options_header', 'dump_options_header',
- 'is_hop_by_hop_header', 'unquote_header_value',
- 'quote_header_value', 'HTTP_STATUS_CODES'],
- 'werkzeug.wrappers': ['BaseResponse', 'BaseRequest', 'Request', 'Response',
- 'AcceptMixin', 'ETagRequestMixin',
- 'ETagResponseMixin', 'ResponseStreamMixin',
- 'CommonResponseDescriptorsMixin', 'UserAgentMixin',
- 'AuthorizationMixin', 'WWWAuthenticateMixin',
- 'CommonRequestDescriptorsMixin'],
+ "werkzeug.debug": ["DebuggedApplication"],
+ "werkzeug.local": [
+ "Local",
+ "LocalManager",
+ "LocalProxy",
+ "LocalStack",
+ "release_local",
+ ],
+ "werkzeug.serving": ["run_simple"],
+ "werkzeug.test": ["Client", "EnvironBuilder", "create_environ", "run_wsgi_app"],
+ "werkzeug.testapp": ["test_app"],
+ "werkzeug.exceptions": ["abort", "Aborter"],
+ "werkzeug.urls": [
+ "url_decode",
+ "url_encode",
+ "url_quote",
+ "url_quote_plus",
+ "url_unquote",
+ "url_unquote_plus",
+ "url_fix",
+ "Href",
+ "iri_to_uri",
+ "uri_to_iri",
+ ],
+ "werkzeug.formparser": ["parse_form_data"],
+ "werkzeug.utils": [
+ "escape",
+ "environ_property",
+ "append_slash_redirect",
+ "redirect",
+ "cached_property",
+ "import_string",
+ "dump_cookie",
+ "parse_cookie",
+ "unescape",
+ "format_string",
+ "find_modules",
+ "header_property",
+ "html",
+ "xhtml",
+ "HTMLBuilder",
+ "validate_arguments",
+ "ArgumentValidationError",
+ "bind_arguments",
+ "secure_filename",
+ ],
+ "werkzeug.wsgi": [
+ "get_current_url",
+ "get_host",
+ "pop_path_info",
+ "peek_path_info",
+ "ClosingIterator",
+ "FileWrapper",
+ "make_line_iter",
+ "LimitedStream",
+ "responder",
+ "wrap_file",
+ "extract_path_info",
+ ],
+ "werkzeug.datastructures": [
+ "MultiDict",
+ "CombinedMultiDict",
+ "Headers",
+ "EnvironHeaders",
+ "ImmutableList",
+ "ImmutableDict",
+ "ImmutableMultiDict",
+ "TypeConversionDict",
+ "ImmutableTypeConversionDict",
+ "Accept",
+ "MIMEAccept",
+ "CharsetAccept",
+ "LanguageAccept",
+ "RequestCacheControl",
+ "ResponseCacheControl",
+ "ETags",
+ "HeaderSet",
+ "WWWAuthenticate",
+ "Authorization",
+ "FileMultiDict",
+ "CallbackDict",
+ "FileStorage",
+ "OrderedMultiDict",
+ "ImmutableOrderedMultiDict",
+ ],
+ "werkzeug.useragents": ["UserAgent"],
+ "werkzeug.http": [
+ "parse_etags",
+ "parse_date",
+ "http_date",
+ "cookie_date",
+ "parse_cache_control_header",
+ "is_resource_modified",
+ "parse_accept_header",
+ "parse_set_header",
+ "quote_etag",
+ "unquote_etag",
+ "generate_etag",
+ "dump_header",
+ "parse_list_header",
+ "parse_dict_header",
+ "parse_authorization_header",
+ "parse_www_authenticate_header",
+ "remove_entity_headers",
+ "is_entity_header",
+ "remove_hop_by_hop_headers",
+ "parse_options_header",
+ "dump_options_header",
+ "is_hop_by_hop_header",
+ "unquote_header_value",
+ "quote_header_value",
+ "HTTP_STATUS_CODES",
+ ],
+ "werkzeug.wrappers": [
+ "BaseResponse",
+ "BaseRequest",
+ "Request",
+ "Response",
+ "AcceptMixin",
+ "ETagRequestMixin",
+ "ETagResponseMixin",
+ "ResponseStreamMixin",
+ "CommonResponseDescriptorsMixin",
+ "UserAgentMixin",
+ "AuthorizationMixin",
+ "WWWAuthenticateMixin",
+ "CommonRequestDescriptorsMixin",
+ ],
"werkzeug.middleware.dispatcher": ["DispatcherMiddleware"],
"werkzeug.middleware.shared_data": ["SharedDataMiddleware"],
- 'werkzeug.security': ['generate_password_hash', 'check_password_hash'],
+ "werkzeug.security": ["generate_password_hash", "check_password_hash"],
# the undocumented easteregg ;-)
- 'werkzeug._internal': ['_easteregg']
+ "werkzeug._internal": ["_easteregg"],
}
# modules that should be imported when accessed as attributes of werkzeug
-attribute_modules = frozenset(['exceptions', 'routing'])
-
+attribute_modules = frozenset(["exceptions", "routing"])
object_origins = {}
-for module, items in iteritems(all_by_module):
+for module, items in all_by_module.items():
for item in items:
object_origins[item] = module
class module(ModuleType):
-
"""Automatically import objects from the modules."""
def __getattr__(self, name):
@@ -119,34 +188,46 @@ class module(ModuleType):
setattr(self, extra_name, getattr(module, extra_name))
return getattr(module, name)
elif name in attribute_modules:
- __import__('werkzeug.' + name)
+ __import__("werkzeug." + name)
return ModuleType.__getattribute__(self, name)
def __dir__(self):
"""Just show what we want to show."""
result = list(new_module.__all__)
- result.extend(('__file__', '__doc__', '__all__',
- '__docformat__', '__name__', '__path__',
- '__package__', '__version__'))
+ result.extend(
+ (
+ "__file__",
+ "__doc__",
+ "__all__",
+ "__docformat__",
+ "__name__",
+ "__path__",
+ "__package__",
+ "__version__",
+ )
+ )
return result
+
# keep a reference to this module so that it's not garbage collected
-old_module = sys.modules['werkzeug']
+old_module = sys.modules["werkzeug"]
# setup the new module and patch it into the dict of loaded modules
-new_module = sys.modules['werkzeug'] = module('werkzeug')
-new_module.__dict__.update({
- '__file__': __file__,
- '__package__': 'werkzeug',
- '__path__': __path__,
- '__doc__': __doc__,
- '__version__': __version__,
- '__all__': tuple(object_origins) + tuple(attribute_modules),
- '__docformat__': 'restructuredtext en'
-})
+new_module = sys.modules["werkzeug"] = module("werkzeug")
+new_module.__dict__.update(
+ {
+ "__file__": __file__,
+ "__package__": "werkzeug",
+ "__path__": __path__,
+ "__doc__": __doc__,
+ "__version__": __version__,
+ "__all__": tuple(object_origins) + tuple(attribute_modules),
+ "__docformat__": "restructuredtext en",
+ }
+)
# Due to bootstrapping issues we need to import exceptions here.
# Don't ask :-(
-__import__('werkzeug.exceptions')
+__import__("werkzeug.exceptions")
diff --git a/src/werkzeug/_compat.py b/src/werkzeug/_compat.py
index f9c5b343..1097983e 100644
--- a/src/werkzeug/_compat.py
+++ b/src/werkzeug/_compat.py
@@ -1,10 +1,8 @@
# flake8: noqa
# This whole file is full of lint errors
-import codecs
-import sys
-import operator
import functools
-import warnings
+import operator
+import sys
try:
import builtins
@@ -13,7 +11,7 @@ except ImportError:
PY2 = sys.version_info[0] == 2
-WIN = sys.platform.startswith('win')
+WIN = sys.platform.startswith("win")
_identity = lambda x: x
@@ -35,15 +33,19 @@ if PY2:
import collections as collections_abc
- exec('def reraise(tp, value, tb=None):\n raise tp, value, tb')
+ exec("def reraise(tp, value, tb=None):\n raise tp, value, tb")
def fix_tuple_repr(obj):
def __repr__(self):
cls = self.__class__
- return '%s(%s)' % (cls.__name__, ', '.join(
- '%s=%r' % (field, self[index])
- for index, field in enumerate(cls._fields)
- ))
+ return "%s(%s)" % (
+ cls.__name__,
+ ", ".join(
+ "%s=%r" % (field, self[index])
+ for index, field in enumerate(cls._fields)
+ ),
+ )
+
obj.__repr__ = __repr__
return obj
@@ -54,12 +56,13 @@ if PY2:
def implements_to_string(cls):
cls.__unicode__ = cls.__str__
- cls.__str__ = lambda x: x.__unicode__().encode('utf-8')
+ cls.__str__ = lambda x: x.__unicode__().encode("utf-8")
return cls
def native_string_result(func):
def wrapper(*args, **kwargs):
- return func(*args, **kwargs).encode('utf-8')
+ return func(*args, **kwargs).encode("utf-8")
+
return functools.update_wrapper(wrapper, func)
def implements_bool(cls):
@@ -68,10 +71,12 @@ if PY2:
return cls
from itertools import imap, izip, ifilter
+
range_type = xrange
from StringIO import StringIO
from cStringIO import StringIO as BytesIO
+
NativeStringIO = BytesIO
def make_literal_wrapper(reference):
@@ -96,33 +101,34 @@ if PY2:
wsgi_get_bytes = _identity
- def wsgi_decoding_dance(s, charset='utf-8', errors='replace'):
+ def wsgi_decoding_dance(s, charset="utf-8", errors="replace"):
return s.decode(charset, errors)
- def wsgi_encoding_dance(s, charset='utf-8', errors='replace'):
+ def wsgi_encoding_dance(s, charset="utf-8", errors="replace"):
if isinstance(s, bytes):
return s
return s.encode(charset, errors)
- def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'):
+ def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"):
if x is None:
return None
if isinstance(x, (bytes, bytearray, buffer)):
return bytes(x)
if isinstance(x, unicode):
return x.encode(charset, errors)
- raise TypeError('Expected bytes')
+ raise TypeError("Expected bytes")
- def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
+ def to_native(x, charset=sys.getdefaultencoding(), errors="strict"):
if x is None or isinstance(x, str):
return x
return x.encode(charset, errors)
+
else:
unichr = chr
text_type = str
- string_types = (str, )
- integer_types = (int, )
+ string_types = (str,)
+ integer_types = (int,)
iterkeys = lambda d, *args, **kwargs: iter(d.keys(*args, **kwargs))
itervalues = lambda d, *args, **kwargs: iter(d.values(*args, **kwargs))
@@ -131,7 +137,7 @@ else:
iterlists = lambda d, *args, **kwargs: iter(d.lists(*args, **kwargs))
iterlistvalues = lambda d, *args, **kwargs: iter(d.listvalues(*args, **kwargs))
- int_to_byte = operator.methodcaller('to_bytes', 1, 'big')
+ int_to_byte = operator.methodcaller("to_bytes", 1, "big")
iter_bytes = functools.partial(map, int_to_byte)
import collections.abc as collections_abc
@@ -152,9 +158,10 @@ else:
range_type = range
from io import StringIO, BytesIO
+
NativeStringIO = StringIO
- _latin1_encode = operator.methodcaller('encode', 'latin1')
+ _latin1_encode = operator.methodcaller("encode", "latin1")
def make_literal_wrapper(reference):
if isinstance(reference, text_type):
@@ -169,38 +176,40 @@ else:
is_text = isinstance(next(tupiter, None), text_type)
for arg in tupiter:
if isinstance(arg, text_type) != is_text:
- raise TypeError('Cannot mix str and bytes arguments (got %s)'
- % repr(tup))
+ raise TypeError(
+ "Cannot mix str and bytes arguments (got %s)" % repr(tup)
+ )
return tup
try_coerce_native = _identity
wsgi_get_bytes = _latin1_encode
- def wsgi_decoding_dance(s, charset='utf-8', errors='replace'):
- return s.encode('latin1').decode(charset, errors)
+ def wsgi_decoding_dance(s, charset="utf-8", errors="replace"):
+ return s.encode("latin1").decode(charset, errors)
- def wsgi_encoding_dance(s, charset='utf-8', errors='replace'):
+ def wsgi_encoding_dance(s, charset="utf-8", errors="replace"):
if isinstance(s, text_type):
s = s.encode(charset)
- return s.decode('latin1', errors)
+ return s.decode("latin1", errors)
- def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'):
+ def to_bytes(x, charset=sys.getdefaultencoding(), errors="strict"):
if x is None:
return None
if isinstance(x, (bytes, bytearray, memoryview)): # noqa
return bytes(x)
if isinstance(x, str):
return x.encode(charset, errors)
- raise TypeError('Expected bytes')
+ raise TypeError("Expected bytes")
- def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
+ def to_native(x, charset=sys.getdefaultencoding(), errors="strict"):
if x is None or isinstance(x, str):
return x
return x.decode(charset, errors)
-def to_unicode(x, charset=sys.getdefaultencoding(), errors='strict',
- allow_none_charset=False):
+def to_unicode(
+ x, charset=sys.getdefaultencoding(), errors="strict", allow_none_charset=False
+):
if x is None:
return None
if not isinstance(x, bytes):
diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py
index 36084438..d8b83363 100644
--- a/src/werkzeug/_internal.py
+++ b/src/werkzeug/_internal.py
@@ -8,41 +8,46 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+import inspect
import re
import string
-import inspect
-from weakref import WeakKeyDictionary
-from datetime import datetime, date
+from datetime import date
+from datetime import datetime
from itertools import chain
+from weakref import WeakKeyDictionary
-from werkzeug._compat import iter_bytes, text_type, int_to_byte, range_type, \
- integer_types
+from ._compat import int_to_byte
+from ._compat import integer_types
+from ._compat import iter_bytes
+from ._compat import range_type
+from ._compat import text_type
_logger = None
_signature_cache = WeakKeyDictionary()
_epoch_ord = date(1970, 1, 1).toordinal()
-_cookie_params = set((b'expires', b'path', b'comment',
- b'max-age', b'secure', b'httponly',
- b'version'))
-_legal_cookie_chars = (string.ascii_letters
- + string.digits
- + u"/=!#$%&'*+-.^_`|~:").encode('ascii')
-
-_cookie_quoting_map = {
- b',': b'\\054',
- b';': b'\\073',
- b'"': b'\\"',
- b'\\': b'\\\\',
+_cookie_params = {
+ b"expires",
+ b"path",
+ b"comment",
+ b"max-age",
+ b"secure",
+ b"httponly",
+ b"version",
}
-for _i in chain(range_type(32), range_type(127, 256)):
- _cookie_quoting_map[int_to_byte(_i)] = ('\\%03o' % _i).encode('latin1')
+_legal_cookie_chars = (
+ string.ascii_letters + string.digits + u"/=!#$%&'*+-.^_`|~:"
+).encode("ascii")
+_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"}
+for _i in chain(range_type(32), range_type(127, 256)):
+ _cookie_quoting_map[int_to_byte(_i)] = ("\\%03o" % _i).encode("latin1")
-_octal_re = re.compile(br'\\[0-3][0-7][0-7]')
-_quote_re = re.compile(br'[\\].')
-_legal_cookie_chars_re = br'[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]'
-_cookie_re = re.compile(br"""
+_octal_re = re.compile(br"\\[0-3][0-7][0-7]")
+_quote_re = re.compile(br"[\\].")
+_legal_cookie_chars_re = br"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]"
+_cookie_re = re.compile(
+ br"""
(?P<key>[^=;]+)
(?:\s*=\s*
(?P<val>
@@ -51,24 +56,27 @@ _cookie_re = re.compile(br"""
)
)?
\s*;
-""", flags=re.VERBOSE)
+""",
+ flags=re.VERBOSE,
+)
class _Missing(object):
-
def __repr__(self):
- return 'no value'
+ return "no value"
def __reduce__(self):
- return '_missing'
+ return "_missing"
+
_missing = _Missing()
def _get_environ(obj):
- env = getattr(obj, 'environ', obj)
- assert isinstance(env, dict), \
- '%r is not a WSGI environment (has to be a dict)' % type(obj).__name__
+ env = getattr(obj, "environ", obj)
+ assert isinstance(env, dict), (
+ "%r is not a WSGI environment (has to be a dict)" % type(obj).__name__
+ )
return env
@@ -77,7 +85,8 @@ def _log(type, message, *args, **kwargs):
global _logger
if _logger is None:
import logging
- _logger = logging.getLogger('werkzeug')
+
+ _logger = logging.getLogger("werkzeug")
if _logger.level == logging.NOTSET:
_logger.setLevel(logging.INFO)
# Only set up a default log handler if the
@@ -90,7 +99,7 @@ def _log(type, message, *args, **kwargs):
def _parse_signature(func):
"""Return a signature object for the function."""
- if hasattr(func, 'im_func'):
+ if hasattr(func, "im_func"):
func = func.im_func
# if we have a cached validator for this function, return it
@@ -99,7 +108,7 @@ def _parse_signature(func):
return parse
# inspect the function signature and collect all the information
- if hasattr(inspect, 'getfullargspec'):
+ if hasattr(inspect, "getfullargspec"):
tup = inspect.getfullargspec(func)
else:
tup = inspect.getargspec(func)
@@ -109,8 +118,9 @@ def _parse_signature(func):
arguments = []
for idx, name in enumerate(positional):
if isinstance(name, list):
- raise TypeError('cannot parse functions that unpack tuples '
- 'in the function signature')
+ raise TypeError(
+ "cannot parse functions that unpack tuples in the function signature"
+ )
try:
default = defaults[idx - arg_count]
except IndexError:
@@ -150,8 +160,17 @@ def _parse_signature(func):
extra.update(kwargs)
kwargs = {}
- return new_args, kwargs, missing, extra, extra_positional, \
- arguments, vararg_var, kwarg_var
+ return (
+ new_args,
+ kwargs,
+ missing,
+ extra,
+ extra_positional,
+ arguments,
+ vararg_var,
+ kwarg_var,
+ )
+
_signature_cache[func] = parse
return parse
@@ -173,12 +192,19 @@ def _date_to_unix(arg):
class _DictAccessorProperty(object):
-
"""Baseclass for `environ_property` and `header_property`."""
+
read_only = False
- def __init__(self, name, default=None, load_func=None, dump_func=None,
- read_only=None, doc=None):
+ def __init__(
+ self,
+ name,
+ default=None,
+ load_func=None,
+ dump_func=None,
+ read_only=None,
+ doc=None,
+ ):
self.name = name
self.default = default
self.load_func = load_func
@@ -203,21 +229,18 @@ class _DictAccessorProperty(object):
def __set__(self, obj, value):
if self.read_only:
- raise AttributeError('read only property')
+ raise AttributeError("read only property")
if self.dump_func is not None:
value = self.dump_func(value)
self.lookup(obj)[self.name] = value
def __delete__(self, obj):
if self.read_only:
- raise AttributeError('read only property')
+ raise AttributeError("read only property")
self.lookup(obj).pop(self.name, None)
def __repr__(self):
- return '<%s %s>' % (
- self.__class__.__name__,
- self.name
- )
+ return "<%s %s>" % (self.__class__.__name__, self.name)
def _cookie_quote(b):
@@ -263,11 +286,11 @@ def _cookie_unquote(b):
k = q_match.start(0)
if q_match and (not o_match or k < j):
_push(b[i:k])
- _push(b[k + 1:k + 2])
+ _push(b[k + 1 : k + 2])
i = k + 2
else:
_push(b[i:j])
- rv.append(int(b[j + 1:j + 4], 8))
+ rv.append(int(b[j + 1 : j + 4], 8))
i = j + 4
return bytes(rv)
@@ -279,12 +302,12 @@ def _cookie_parse_impl(b):
n = len(b)
while i < n:
- match = _cookie_re.search(b + b';', i)
+ match = _cookie_re.search(b + b";", i)
if not match:
break
- key = match.group('key').strip()
- value = match.group('val') or b''
+ key = match.group("key").strip()
+ value = match.group("val") or b""
i = match.end(0)
# Ignore parameters. We have no interest in them.
@@ -295,20 +318,20 @@ def _cookie_parse_impl(b):
def _encode_idna(domain):
# If we're given bytes, make sure they fit into ASCII
if not isinstance(domain, text_type):
- domain.decode('ascii')
+ domain.decode("ascii")
return domain
# Otherwise check if it's already ascii, then return
try:
- return domain.encode('ascii')
+ return domain.encode("ascii")
except UnicodeError:
pass
# Otherwise encode each part separately
- parts = domain.split('.')
+ parts = domain.split(".")
for idx, part in enumerate(parts):
- parts[idx] = part.encode('idna')
- return b'.'.join(parts)
+ parts[idx] = part.encode("idna")
+ return b".".join(parts)
def _decode_idna(domain):
@@ -317,47 +340,54 @@ def _decode_idna(domain):
# unicode error, then we already have a decoded idna domain
if isinstance(domain, text_type):
try:
- domain = domain.encode('ascii')
+ domain = domain.encode("ascii")
except UnicodeError:
return domain
# Decode each part separately. If a part fails, try to
# decode it with ascii and silently ignore errors. This makes
# most sense because the idna codec does not have error handling
- parts = domain.split(b'.')
+ parts = domain.split(b".")
for idx, part in enumerate(parts):
try:
- parts[idx] = part.decode('idna')
+ parts[idx] = part.decode("idna")
except UnicodeError:
- parts[idx] = part.decode('ascii', 'ignore')
+ parts[idx] = part.decode("ascii", "ignore")
- return '.'.join(parts)
+ return ".".join(parts)
def _make_cookie_domain(domain):
if domain is None:
return None
domain = _encode_idna(domain)
- if b':' in domain:
- domain = domain.split(b':', 1)[0]
- if b'.' in domain:
+ if b":" in domain:
+ domain = domain.split(b":", 1)[0]
+ if b"." in domain:
return domain
raise ValueError(
- 'Setting \'domain\' for a cookie on a server running locally (ex: '
- 'localhost) is not supported by complying browsers. You should '
- 'have something like: \'127.0.0.1 localhost dev.localhost\' on '
- 'your hosts file and then point your server to run on '
- '\'dev.localhost\' and also set \'domain\' for \'dev.localhost\''
+ "Setting 'domain' for a cookie on a server running locally (ex: "
+ "localhost) is not supported by complying browsers. You should "
+ "have something like: '127.0.0.1 localhost dev.localhost' on "
+ "your hosts file and then point your server to run on "
+ "'dev.localhost' and also set 'domain' for 'dev.localhost'"
)
def _easteregg(app=None):
"""Like the name says. But who knows how it works?"""
+
def bzzzzzzz(gyver):
import base64
import zlib
- return zlib.decompress(base64.b64decode(gyver)).decode('ascii')
- gyver = u'\n'.join([x + (77 - len(x)) * u' ' for x in bzzzzzzz(b'''
+
+ return zlib.decompress(base64.b64decode(gyver)).decode("ascii")
+
+ gyver = u"\n".join(
+ [
+ x + (77 - len(x)) * u" "
+ for x in bzzzzzzz(
+ b"""
eJyFlzuOJDkMRP06xRjymKgDJCDQStBYT8BCgK4gTwfQ2fcFs2a2FzvZk+hvlcRvRJD148efHt9m
9Xz94dRY5hGt1nrYcXx7us9qlcP9HHNh28rz8dZj+q4rynVFFPdlY4zH873NKCexrDM6zxxRymzz
4QIxzK4bth1PV7+uHn6WXZ5C4ka/+prFzx3zWLMHAVZb8RRUxtFXI5DTQ2n3Hi2sNI+HK43AOWSY
@@ -388,16 +418,22 @@ p1qXK3Du2mnr5INXmT/78KI12n11EFBkJHHp0wJyLe9MvPNUGYsf+170maayRoy2lURGHAIapSpQ
krEDuNoJCHNlZYhKpvw4mspVWxqo415n8cD62N9+EfHrAvqQnINStetek7RY2Urv8nxsnGaZfRr/
nhXbJ6m/yl1LzYqscDZA9QHLNbdaSTTr+kFg3bC0iYbX/eQy0Bv3h4B50/SGYzKAXkCeOLI3bcAt
mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p
-7f2zLkGNv8b191cD/3vs9Q833z8t''').splitlines()])
+7f2zLkGNv8b191cD/3vs9Q833z8t"""
+ ).splitlines()
+ ]
+ )
def easteregged(environ, start_response):
def injecting_start_response(status, headers, exc_info=None):
- headers.append(('X-Powered-By', 'Werkzeug'))
+ headers.append(("X-Powered-By", "Werkzeug"))
return start_response(status, headers, exc_info)
- if app is not None and environ.get('QUERY_STRING') != 'macgybarchakku':
+
+ if app is not None and environ.get("QUERY_STRING") != "macgybarchakku":
return app(environ, injecting_start_response)
- injecting_start_response('200 OK', [('Content-Type', 'text/html')])
- return [(u'''
+ injecting_start_response("200 OK", [("Content-Type", "text/html")])
+ return [
+ (
+ u"""
<!DOCTYPE html>
<html>
<head>
@@ -415,5 +451,9 @@ mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p
<p>the Swiss Army knife of Python web development.</p>
<pre>%s\n\n\n</pre>
</body>
-</html>''' % gyver).encode('latin1')]
+</html>"""
+ % gyver
+ ).encode("latin1")
+ ]
+
return easteregged
diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py
index 2b21e146..f06a63d5 100644
--- a/src/werkzeug/_reloader.py
+++ b/src/werkzeug/_reloader.py
@@ -1,12 +1,14 @@
import os
-import sys
-import time
import subprocess
+import sys
import threading
+import time
from itertools import chain
-from werkzeug._internal import _log
-from werkzeug._compat import PY2, iteritems, text_type
+from ._compat import iteritems
+from ._compat import PY2
+from ._compat import text_type
+from ._internal import _log
def _iter_module_files():
@@ -19,10 +21,11 @@ def _iter_module_files():
for module in list(sys.modules.values()):
if module is None:
continue
- filename = getattr(module, '__file__', None)
+ filename = getattr(module, "__file__", None)
if filename:
- if os.path.isdir(filename) and \
- os.path.exists(os.path.join(filename, "__init__.py")):
+ if os.path.isdir(filename) and os.path.exists(
+ os.path.join(filename, "__init__.py")
+ ):
filename = os.path.join(filename, "__init__.py")
old = None
@@ -32,22 +35,23 @@ def _iter_module_files():
if filename == old:
break
else:
- if filename[-4:] in ('.pyc', '.pyo'):
+ if filename[-4:] in (".pyc", ".pyo"):
filename = filename[:-1]
yield filename
def _find_observable_paths(extra_files=None):
"""Finds all paths that should be observed."""
- rv = set(os.path.dirname(os.path.abspath(x))
- if os.path.isfile(x) else os.path.abspath(x)
- for x in sys.path)
+ rv = set(
+ os.path.dirname(os.path.abspath(x)) if os.path.isfile(x) else os.path.abspath(x)
+ for x in sys.path
+ )
for filename in extra_files or ():
rv.add(os.path.dirname(os.path.abspath(filename)))
for module in list(sys.modules.values()):
- fn = getattr(module, '__file__', None)
+ fn = getattr(module, "__file__", None)
if fn is None:
continue
fn = os.path.abspath(fn)
@@ -125,7 +129,8 @@ def _find_common_roots(paths):
for prefix, child in iteritems(node):
_walk(child, path + (prefix,))
if not node:
- rv.add('/'.join(path))
+ rv.add("/".join(path))
+
_walk(root, ())
return rv
@@ -139,8 +144,7 @@ class ReloaderLoop(object):
_sleep = staticmethod(time.sleep)
def __init__(self, extra_files=None, interval=1):
- self.extra_files = set(os.path.abspath(x)
- for x in extra_files or ())
+ self.extra_files = set(os.path.abspath(x) for x in extra_files or ())
self.interval = interval
def run(self):
@@ -151,26 +155,25 @@ class ReloaderLoop(object):
but running the reloader thread.
"""
while 1:
- _log('info', ' * Restarting with %s' % self.name)
+ _log("info", " * Restarting with %s" % self.name)
args = _get_args_for_reloading()
# a weird bug on windows. sometimes unicode strings end up in the
# environment and subprocess.call does not like this, encode them
# to latin1 and continue.
- if os.name == 'nt' and PY2:
+ if os.name == "nt" and PY2:
new_environ = {}
for key, value in iteritems(os.environ):
if isinstance(key, text_type):
- key = key.encode('iso-8859-1')
+ key = key.encode("iso-8859-1")
if isinstance(value, text_type):
- value = value.encode('iso-8859-1')
+ value = value.encode("iso-8859-1")
new_environ[key] = value
else:
new_environ = os.environ.copy()
- new_environ['WERKZEUG_RUN_MAIN'] = 'true'
- exit_code = subprocess.call(args, env=new_environ,
- close_fds=False)
+ new_environ["WERKZEUG_RUN_MAIN"] = "true"
+ exit_code = subprocess.call(args, env=new_environ, close_fds=False)
if exit_code != 3:
return exit_code
@@ -180,17 +183,16 @@ class ReloaderLoop(object):
def log_reload(self, filename):
filename = os.path.abspath(filename)
- _log('info', ' * Detected change in %r, reloading' % filename)
+ _log("info", " * Detected change in %r, reloading" % filename)
class StatReloaderLoop(ReloaderLoop):
- name = 'stat'
+ name = "stat"
def run(self):
mtimes = {}
while 1:
- for filename in chain(_iter_module_files(),
- self.extra_files):
+ for filename in chain(_iter_module_files(), self.extra_files):
try:
mtime = os.stat(filename).st_mtime
except OSError:
@@ -206,11 +208,11 @@ class StatReloaderLoop(ReloaderLoop):
class WatchdogReloaderLoop(ReloaderLoop):
-
def __init__(self, *args, **kwargs):
ReloaderLoop.__init__(self, *args, **kwargs)
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
+
self.observable_paths = set()
def _check_modification(filename):
@@ -218,11 +220,10 @@ class WatchdogReloaderLoop(ReloaderLoop):
self.trigger_reload(filename)
dirname = os.path.dirname(filename)
if dirname.startswith(tuple(self.observable_paths)):
- if filename.endswith(('.pyc', '.pyo', '.py')):
+ if filename.endswith((".pyc", ".pyo", ".py")):
self.trigger_reload(filename)
class _CustomHandler(FileSystemEventHandler):
-
def on_created(self, event):
_check_modification(event.src_path)
@@ -237,9 +238,9 @@ class WatchdogReloaderLoop(ReloaderLoop):
_check_modification(event.src_path)
reloader_name = Observer.__name__.lower()
- if reloader_name.endswith('observer'):
+ if reloader_name.endswith("observer"):
reloader_name = reloader_name[:-8]
- reloader_name += ' reloader'
+ reloader_name += " reloader"
self.name = reloader_name
@@ -267,7 +268,8 @@ class WatchdogReloaderLoop(ReloaderLoop):
if path not in watches:
try:
watches[path] = observer.schedule(
- self.event_handler, path, recursive=True)
+ self.event_handler, path, recursive=True
+ )
except OSError:
# Clear this path from list of watches We don't want
# the same error message showing again in the next
@@ -287,17 +289,14 @@ class WatchdogReloaderLoop(ReloaderLoop):
sys.exit(3)
-reloader_loops = {
- 'stat': StatReloaderLoop,
- 'watchdog': WatchdogReloaderLoop,
-}
+reloader_loops = {"stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop}
try:
- __import__('watchdog.observers')
+ __import__("watchdog.observers")
except ImportError:
- reloader_loops['auto'] = reloader_loops['stat']
+ reloader_loops["auto"] = reloader_loops["stat"]
else:
- reloader_loops['auto'] = reloader_loops['watchdog']
+ reloader_loops["auto"] = reloader_loops["watchdog"]
def ensure_echo_on():
@@ -316,14 +315,14 @@ def ensure_echo_on():
termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes)
-def run_with_reloader(main_func, extra_files=None, interval=1,
- reloader_type='auto'):
+def run_with_reloader(main_func, extra_files=None, interval=1, reloader_type="auto"):
"""Run the given function in an independent python interpreter."""
import signal
+
reloader = reloader_loops[reloader_type](extra_files, interval)
signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
try:
- if os.environ.get('WERKZEUG_RUN_MAIN') == 'true':
+ if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
ensure_echo_on()
t = threading.Thread(target=main_func, args=())
t.setDaemon(True)
diff --git a/src/werkzeug/contrib/atom.py b/src/werkzeug/contrib/atom.py
index 5c631f43..d079d2bf 100644
--- a/src/werkzeug/contrib/atom.py
+++ b/src/werkzeug/contrib/atom.py
@@ -23,9 +23,11 @@
"""
import warnings
from datetime import datetime
-from werkzeug.utils import escape
-from werkzeug.wrappers import BaseResponse
-from werkzeug._compat import implements_to_string, string_types
+
+from .._compat import implements_to_string
+from .._compat import string_types
+from ..utils import escape
+from ..wrappers import BaseResponse
warnings.warn(
"'werkzeug.contrib.atom' is deprecated as of version 0.15 and will"
@@ -34,18 +36,21 @@ warnings.warn(
stacklevel=2,
)
-XHTML_NAMESPACE = 'http://www.w3.org/1999/xhtml'
+XHTML_NAMESPACE = "http://www.w3.org/1999/xhtml"
def _make_text_block(name, content, content_type=None):
"""Helper function for the builder that creates an XML text block."""
- if content_type == 'xhtml':
- return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % \
- (name, XHTML_NAMESPACE, content, name)
+ if content_type == "xhtml":
+ return u'<%s type="xhtml"><div xmlns="%s">%s</div></%s>\n' % (
+ name,
+ XHTML_NAMESPACE,
+ content,
+ name,
+ )
if not content_type:
- return u'<%s>%s</%s>\n' % (name, escape(content), name)
- return u'<%s type="%s">%s</%s>\n' % (name, content_type,
- escape(content), name)
+ return u"<%s>%s</%s>\n" % (name, escape(content), name)
+ return u'<%s type="%s">%s</%s>\n' % (name, content_type, escape(content), name)
def format_iso8601(obj):
@@ -53,7 +58,7 @@ def format_iso8601(obj):
iso8601 = obj.isoformat()
if obj.tzinfo:
return iso8601
- return iso8601 + 'Z'
+ return iso8601 + "Z"
@implements_to_string
@@ -106,42 +111,43 @@ class AtomFeed(object):
Everywhere where a list is demanded, any iterable can be used.
"""
- default_generator = ('Werkzeug', None, None)
+ default_generator = ("Werkzeug", None, None)
def __init__(self, title=None, entries=None, **kwargs):
self.title = title
- self.title_type = kwargs.get('title_type', 'text')
- self.url = kwargs.get('url')
- self.feed_url = kwargs.get('feed_url', self.url)
- self.id = kwargs.get('id', self.feed_url)
- self.updated = kwargs.get('updated')
- self.author = kwargs.get('author', ())
- self.icon = kwargs.get('icon')
- self.logo = kwargs.get('logo')
- self.rights = kwargs.get('rights')
- self.rights_type = kwargs.get('rights_type')
- self.subtitle = kwargs.get('subtitle')
- self.subtitle_type = kwargs.get('subtitle_type', 'text')
- self.generator = kwargs.get('generator')
+ self.title_type = kwargs.get("title_type", "text")
+ self.url = kwargs.get("url")
+ self.feed_url = kwargs.get("feed_url", self.url)
+ self.id = kwargs.get("id", self.feed_url)
+ self.updated = kwargs.get("updated")
+ self.author = kwargs.get("author", ())
+ self.icon = kwargs.get("icon")
+ self.logo = kwargs.get("logo")
+ self.rights = kwargs.get("rights")
+ self.rights_type = kwargs.get("rights_type")
+ self.subtitle = kwargs.get("subtitle")
+ self.subtitle_type = kwargs.get("subtitle_type", "text")
+ self.generator = kwargs.get("generator")
if self.generator is None:
self.generator = self.default_generator
- self.links = kwargs.get('links', [])
+ self.links = kwargs.get("links", [])
self.entries = list(entries) if entries else []
- if not hasattr(self.author, '__iter__') \
- or isinstance(self.author, string_types + (dict,)):
+ if not hasattr(self.author, "__iter__") or isinstance(
+ self.author, string_types + (dict,)
+ ):
self.author = [self.author]
for i, author in enumerate(self.author):
if not isinstance(author, dict):
- self.author[i] = {'name': author}
+ self.author[i] = {"name": author}
if not self.title:
- raise ValueError('title is required')
+ raise ValueError("title is required")
if not self.id:
- raise ValueError('id is required')
+ raise ValueError("id is required")
for author in self.author:
- if 'name' not in author:
- raise TypeError('author must contain at least a name')
+ if "name" not in author:
+ raise TypeError("author must contain at least a name")
def add(self, *args, **kwargs):
"""Add a new entry to the feed. This function can either be called
@@ -151,14 +157,14 @@ class AtomFeed(object):
if len(args) == 1 and not kwargs and isinstance(args[0], FeedEntry):
self.entries.append(args[0])
else:
- kwargs['feed_url'] = self.feed_url
+ kwargs["feed_url"] = self.feed_url
self.entries.append(FeedEntry(*args, **kwargs))
def __repr__(self):
- return '<%s %r (%d entries)>' % (
+ return "<%s %r (%d entries)>" % (
self.__class__.__name__,
self.title,
- len(self.entries)
+ len(self.entries),
)
def generate(self):
@@ -166,7 +172,7 @@ class AtomFeed(object):
# atom demands either an author element in every entry or a global one
if not self.author:
if any(not e.author for e in self.entries):
- self.author = ({'name': 'Unknown author'},)
+ self.author = ({"name": "Unknown author"},)
if not self.updated:
dates = sorted([entry.updated for entry in self.entries])
@@ -174,56 +180,54 @@ class AtomFeed(object):
yield u'<?xml version="1.0" encoding="utf-8"?>\n'
yield u'<feed xmlns="http://www.w3.org/2005/Atom">\n'
- yield ' ' + _make_text_block('title', self.title, self.title_type)
- yield u' <id>%s</id>\n' % escape(self.id)
- yield u' <updated>%s</updated>\n' % format_iso8601(self.updated)
+ yield " " + _make_text_block("title", self.title, self.title_type)
+ yield u" <id>%s</id>\n" % escape(self.id)
+ yield u" <updated>%s</updated>\n" % format_iso8601(self.updated)
if self.url:
yield u' <link href="%s" />\n' % escape(self.url)
if self.feed_url:
- yield u' <link href="%s" rel="self" />\n' % \
- escape(self.feed_url)
+ yield u' <link href="%s" rel="self" />\n' % escape(self.feed_url)
for link in self.links:
- yield u' <link %s/>\n' % ''.join('%s="%s" ' %
- (k, escape(link[k])) for k in link)
+ yield u" <link %s/>\n" % "".join(
+ '%s="%s" ' % (k, escape(link[k])) for k in link
+ )
for author in self.author:
- yield u' <author>\n'
- yield u' <name>%s</name>\n' % escape(author['name'])
- if 'uri' in author:
- yield u' <uri>%s</uri>\n' % escape(author['uri'])
- if 'email' in author:
- yield ' <email>%s</email>\n' % escape(author['email'])
- yield ' </author>\n'
+ yield u" <author>\n"
+ yield u" <name>%s</name>\n" % escape(author["name"])
+ if "uri" in author:
+ yield u" <uri>%s</uri>\n" % escape(author["uri"])
+ if "email" in author:
+ yield " <email>%s</email>\n" % escape(author["email"])
+ yield " </author>\n"
if self.subtitle:
- yield ' ' + _make_text_block('subtitle', self.subtitle,
- self.subtitle_type)
+ yield " " + _make_text_block("subtitle", self.subtitle, self.subtitle_type)
if self.icon:
- yield u' <icon>%s</icon>\n' % escape(self.icon)
+ yield u" <icon>%s</icon>\n" % escape(self.icon)
if self.logo:
- yield u' <logo>%s</logo>\n' % escape(self.logo)
+ yield u" <logo>%s</logo>\n" % escape(self.logo)
if self.rights:
- yield ' ' + _make_text_block('rights', self.rights,
- self.rights_type)
+ yield " " + _make_text_block("rights", self.rights, self.rights_type)
generator_name, generator_url, generator_version = self.generator
if generator_name or generator_url or generator_version:
- tmp = [u' <generator']
+ tmp = [u" <generator"]
if generator_url:
tmp.append(u' uri="%s"' % escape(generator_url))
if generator_version:
tmp.append(u' version="%s"' % escape(generator_version))
- tmp.append(u'>%s</generator>\n' % escape(generator_name))
- yield u''.join(tmp)
+ tmp.append(u">%s</generator>\n" % escape(generator_name))
+ yield u"".join(tmp)
for entry in self.entries:
for line in entry.generate():
- yield u' ' + line
- yield u'</feed>\n'
+ yield u" " + line
+ yield u"</feed>\n"
def to_string(self):
"""Convert the feed into a string."""
- return u''.join(self.generate())
+ return u"".join(self.generate())
def get_response(self):
"""Return a response object for the feed."""
- return BaseResponse(self.to_string(), mimetype='application/atom+xml')
+ return BaseResponse(self.to_string(), mimetype="application/atom+xml")
def __call__(self, environ, start_response):
"""Use the class as WSGI response object."""
@@ -282,80 +286,77 @@ class FeedEntry(object):
def __init__(self, title=None, content=None, feed_url=None, **kwargs):
self.title = title
- self.title_type = kwargs.get('title_type', 'text')
+ self.title_type = kwargs.get("title_type", "text")
self.content = content
- self.content_type = kwargs.get('content_type', 'html')
- self.url = kwargs.get('url')
- self.id = kwargs.get('id', self.url)
- self.updated = kwargs.get('updated')
- self.summary = kwargs.get('summary')
- self.summary_type = kwargs.get('summary_type', 'html')
- self.author = kwargs.get('author', ())
- self.published = kwargs.get('published')
- self.rights = kwargs.get('rights')
- self.links = kwargs.get('links', [])
- self.categories = kwargs.get('categories', [])
- self.xml_base = kwargs.get('xml_base', feed_url)
-
- if not hasattr(self.author, '__iter__') \
- or isinstance(self.author, string_types + (dict,)):
+ self.content_type = kwargs.get("content_type", "html")
+ self.url = kwargs.get("url")
+ self.id = kwargs.get("id", self.url)
+ self.updated = kwargs.get("updated")
+ self.summary = kwargs.get("summary")
+ self.summary_type = kwargs.get("summary_type", "html")
+ self.author = kwargs.get("author", ())
+ self.published = kwargs.get("published")
+ self.rights = kwargs.get("rights")
+ self.links = kwargs.get("links", [])
+ self.categories = kwargs.get("categories", [])
+ self.xml_base = kwargs.get("xml_base", feed_url)
+
+ if not hasattr(self.author, "__iter__") or isinstance(
+ self.author, string_types + (dict,)
+ ):
self.author = [self.author]
for i, author in enumerate(self.author):
if not isinstance(author, dict):
- self.author[i] = {'name': author}
+ self.author[i] = {"name": author}
if not self.title:
- raise ValueError('title is required')
+ raise ValueError("title is required")
if not self.id:
- raise ValueError('id is required')
+ raise ValueError("id is required")
if not self.updated:
- raise ValueError('updated is required')
+ raise ValueError("updated is required")
def __repr__(self):
- return '<%s %r>' % (
- self.__class__.__name__,
- self.title
- )
+ return "<%s %r>" % (self.__class__.__name__, self.title)
def generate(self):
"""Yields pieces of ATOM XML."""
- base = ''
+ base = ""
if self.xml_base:
base = ' xml:base="%s"' % escape(self.xml_base)
- yield u'<entry%s>\n' % base
- yield u' ' + _make_text_block('title', self.title, self.title_type)
- yield u' <id>%s</id>\n' % escape(self.id)
- yield u' <updated>%s</updated>\n' % format_iso8601(self.updated)
+ yield u"<entry%s>\n" % base
+ yield u" " + _make_text_block("title", self.title, self.title_type)
+ yield u" <id>%s</id>\n" % escape(self.id)
+ yield u" <updated>%s</updated>\n" % format_iso8601(self.updated)
if self.published:
- yield u' <published>%s</published>\n' % \
- format_iso8601(self.published)
+ yield u" <published>%s</published>\n" % format_iso8601(self.published)
if self.url:
yield u' <link href="%s" />\n' % escape(self.url)
for author in self.author:
- yield u' <author>\n'
- yield u' <name>%s</name>\n' % escape(author['name'])
- if 'uri' in author:
- yield u' <uri>%s</uri>\n' % escape(author['uri'])
- if 'email' in author:
- yield u' <email>%s</email>\n' % escape(author['email'])
- yield u' </author>\n'
+ yield u" <author>\n"
+ yield u" <name>%s</name>\n" % escape(author["name"])
+ if "uri" in author:
+ yield u" <uri>%s</uri>\n" % escape(author["uri"])
+ if "email" in author:
+ yield u" <email>%s</email>\n" % escape(author["email"])
+ yield u" </author>\n"
for link in self.links:
- yield u' <link %s/>\n' % ''.join('%s="%s" ' %
- (k, escape(link[k])) for k in link)
+ yield u" <link %s/>\n" % "".join(
+ '%s="%s" ' % (k, escape(link[k])) for k in link
+ )
for category in self.categories:
- yield u' <category %s/>\n' % ''.join('%s="%s" ' %
- (k, escape(category[k])) for k in category)
+ yield u" <category %s/>\n" % "".join(
+ '%s="%s" ' % (k, escape(category[k])) for k in category
+ )
if self.summary:
- yield u' ' + _make_text_block('summary', self.summary,
- self.summary_type)
+ yield u" " + _make_text_block("summary", self.summary, self.summary_type)
if self.content:
- yield u' ' + _make_text_block('content', self.content,
- self.content_type)
- yield u'</entry>\n'
+ yield u" " + _make_text_block("content", self.content, self.content_type)
+ yield u"</entry>\n"
def to_string(self):
"""Convert the feed item into a unicode object."""
- return u''.join(self.generate())
+ return u"".join(self.generate())
def __str__(self):
return self.to_string()
diff --git a/src/werkzeug/contrib/cache.py b/src/werkzeug/contrib/cache.py
index 8f5a33ae..79c749b5 100644
--- a/src/werkzeug/contrib/cache.py
+++ b/src/werkzeug/contrib/cache.py
@@ -56,23 +56,27 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+import errno
import os
+import platform
import re
-import errno
import tempfile
-import platform
import warnings
from hashlib import md5
from time import time
+
+from .._compat import integer_types
+from .._compat import iteritems
+from .._compat import string_types
+from .._compat import text_type
+from .._compat import to_native
+from ..posixemulation import rename
+
try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle
-from werkzeug._compat import iteritems, string_types, text_type, \
- integer_types, to_native
-from werkzeug.posixemulation import rename
-
warnings.warn(
"'werkzeug.contrib.cache' is deprecated as of version 0.15 and will"
" be removed in version 1.0. It has moved to https://github.com"
@@ -93,13 +97,12 @@ def _items(mappingorseq):
... assert k*k == v
"""
- if hasattr(mappingorseq, 'items'):
+ if hasattr(mappingorseq, "items"):
return iteritems(mappingorseq)
return mappingorseq
class BaseCache(object):
-
"""Baseclass for the cache systems. All the cache systems implement this
API or a superset of it.
@@ -224,10 +227,10 @@ class BaseCache(object):
:param key: the key to check
"""
raise NotImplementedError(
- '%s doesn\'t have an efficient implementation of `has`. That '
- 'means it is impossible to check whether a key exists without '
- 'fully loading the key\'s data. Consider using `self.get` '
- 'explicitly if you don\'t care about performance.'
+ "%s doesn't have an efficient implementation of `has`. That "
+ "means it is impossible to check whether a key exists without "
+ "fully loading the key's data. Consider using `self.get` "
+ "explicitly if you don't care about performance."
)
def clear(self):
@@ -267,7 +270,6 @@ class BaseCache(object):
class NullCache(BaseCache):
-
"""A cache that doesn't cache. This can be useful for unit testing.
:param default_timeout: a dummy parameter that is ignored but exists
@@ -279,7 +281,6 @@ class NullCache(BaseCache):
class SimpleCache(BaseCache):
-
"""Simple memory cache for single process environments. This class exists
mainly for the development server and is not 100% thread safe. It tries
to use as many atomic operations as possible and no locks for simplicity
@@ -325,15 +326,13 @@ class SimpleCache(BaseCache):
def set(self, key, value, timeout=None):
expires = self._normalize_timeout(timeout)
self._prune()
- self._cache[key] = (expires, pickle.dumps(value,
- pickle.HIGHEST_PROTOCOL))
+ self._cache[key] = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
return True
def add(self, key, value, timeout=None):
expires = self._normalize_timeout(timeout)
self._prune()
- item = (expires, pickle.dumps(value,
- pickle.HIGHEST_PROTOCOL))
+ item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
if key in self._cache:
return False
self._cache.setdefault(key, item)
@@ -349,11 +348,11 @@ class SimpleCache(BaseCache):
except KeyError:
return False
-_test_memcached_key = re.compile(r'[^\x00-\x21\xff]{1,250}$').match
+_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match
-class MemcachedCache(BaseCache):
+class MemcachedCache(BaseCache):
"""A cache that uses memcached as backend.
The first argument can either be an object that resembles the API of a
@@ -392,10 +391,10 @@ class MemcachedCache(BaseCache):
BaseCache.__init__(self, default_timeout)
if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
- servers = ['127.0.0.1:11211']
+ servers = ["127.0.0.1:11211"]
self._client = self.import_preferred_memcache_lib(servers)
if self._client is None:
- raise RuntimeError('no memcache module found')
+ raise RuntimeError("no memcache module found")
else:
# NOTE: servers is actually an already initialized memcache
# client.
@@ -404,7 +403,7 @@ class MemcachedCache(BaseCache):
self.key_prefix = to_native(key_prefix)
def _normalize_key(self, key):
- key = to_native(key, 'utf-8')
+ key = to_native(key, "utf-8")
if self.key_prefix:
key = self.key_prefix + key
return key
@@ -484,7 +483,7 @@ class MemcachedCache(BaseCache):
def has(self, key):
key = self._normalize_key(key)
if _test_memcached_key(key):
- return self._client.append(key, '')
+ return self._client.append(key, "")
return False
def clear(self):
@@ -534,7 +533,6 @@ GAEMemcachedCache = MemcachedCache
class RedisCache(BaseCache):
-
"""Uses the Redis key-value store as a cache backend.
The first argument can be either a string denoting address of the Redis
@@ -570,24 +568,32 @@ class RedisCache(BaseCache):
Any additional keyword arguments will be passed to ``redis.Redis``.
"""
- def __init__(self, host='localhost', port=6379, password=None,
- db=0, default_timeout=300, key_prefix=None, **kwargs):
+ def __init__(
+ self,
+ host="localhost",
+ port=6379,
+ password=None,
+ db=0,
+ default_timeout=300,
+ key_prefix=None,
+ **kwargs
+ ):
BaseCache.__init__(self, default_timeout)
if host is None:
- raise ValueError('RedisCache host parameter may not be None')
+ raise ValueError("RedisCache host parameter may not be None")
if isinstance(host, string_types):
try:
import redis
except ImportError:
- raise RuntimeError('no redis module found')
- if kwargs.get('decode_responses', None):
- raise ValueError('decode_responses is not supported by '
- 'RedisCache.')
- self._client = redis.Redis(host=host, port=port, password=password,
- db=db, **kwargs)
+ raise RuntimeError("no redis module found")
+ if kwargs.get("decode_responses", None):
+ raise ValueError("decode_responses is not supported by RedisCache.")
+ self._client = redis.Redis(
+ host=host, port=port, password=password, db=db, **kwargs
+ )
else:
self._client = host
- self.key_prefix = key_prefix or ''
+ self.key_prefix = key_prefix or ""
def _normalize_timeout(self, timeout):
timeout = BaseCache._normalize_timeout(self, timeout)
@@ -601,8 +607,8 @@ class RedisCache(BaseCache):
"""
t = type(value)
if t in integer_types:
- return str(value).encode('ascii')
- return b'!' + pickle.dumps(value)
+ return str(value).encode("ascii")
+ return b"!" + pickle.dumps(value)
def load_object(self, value):
"""The reversal of :meth:`dump_object`. This might be called with
@@ -610,7 +616,7 @@ class RedisCache(BaseCache):
"""
if value is None:
return None
- if value.startswith(b'!'):
+ if value.startswith(b"!"):
try:
return pickle.loads(value[1:])
except pickle.PickleError:
@@ -633,20 +639,19 @@ class RedisCache(BaseCache):
timeout = self._normalize_timeout(timeout)
dump = self.dump_object(value)
if timeout == -1:
- result = self._client.set(name=self.key_prefix + key,
- value=dump)
+ result = self._client.set(name=self.key_prefix + key, value=dump)
else:
- result = self._client.setex(name=self.key_prefix + key,
- value=dump, time=timeout)
+ result = self._client.setex(
+ name=self.key_prefix + key, value=dump, time=timeout
+ )
return result
def add(self, key, value, timeout=None):
timeout = self._normalize_timeout(timeout)
dump = self.dump_object(value)
- return (
- self._client.setnx(name=self.key_prefix + key, value=dump)
- and self._client.expire(name=self.key_prefix + key, time=timeout)
- )
+ return self._client.setnx(
+ name=self.key_prefix + key, value=dump
+ ) and self._client.expire(name=self.key_prefix + key, time=timeout)
def set_many(self, mapping, timeout=None):
timeout = self._normalize_timeout(timeout)
@@ -659,8 +664,7 @@ class RedisCache(BaseCache):
if timeout == -1:
pipe.set(name=self.key_prefix + key, value=dump)
else:
- pipe.setex(name=self.key_prefix + key, value=dump,
- time=timeout)
+ pipe.setex(name=self.key_prefix + key, value=dump, time=timeout)
return pipe.execute()
def delete(self, key):
@@ -679,7 +683,7 @@ class RedisCache(BaseCache):
def clear(self):
status = False
if self.key_prefix:
- keys = self._client.keys(self.key_prefix + '*')
+ keys = self._client.keys(self.key_prefix + "*")
if keys:
status = self._client.delete(*keys)
else:
@@ -694,7 +698,6 @@ class RedisCache(BaseCache):
class FileSystemCache(BaseCache):
-
"""A cache that stores the items on the file system. This cache depends
on being the only user of the `cache_dir`. Make absolutely sure that
nobody but this cache stores files there or otherwise the cache will
@@ -711,12 +714,11 @@ class FileSystemCache(BaseCache):
"""
#: used for temporary files by the FileSystemCache
- _fs_transaction_suffix = '.__wz_cache'
+ _fs_transaction_suffix = ".__wz_cache"
#: keep amount of files in a cache element
- _fs_count_file = '__wz_cache_count'
+ _fs_count_file = "__wz_cache_count"
- def __init__(self, cache_dir, threshold=500, default_timeout=300,
- mode=0o600):
+ def __init__(self, cache_dir, threshold=500, default_timeout=300, mode=0o600):
BaseCache.__init__(self, default_timeout)
self._path = cache_dir
self._threshold = threshold
@@ -754,11 +756,14 @@ class FileSystemCache(BaseCache):
def _list_dir(self):
"""return a list of (fully qualified) cache filenames
"""
- mgmt_files = [self._get_filename(name).split('/')[-1]
- for name in (self._fs_count_file,)]
- return [os.path.join(self._path, fn) for fn in os.listdir(self._path)
- if not fn.endswith(self._fs_transaction_suffix)
- and fn not in mgmt_files]
+ mgmt_files = [
+ self._get_filename(name).split("/")[-1] for name in (self._fs_count_file,)
+ ]
+ return [
+ os.path.join(self._path, fn)
+ for fn in os.listdir(self._path)
+ if not fn.endswith(self._fs_transaction_suffix) and fn not in mgmt_files
+ ]
def _prune(self):
if self._threshold == 0 or not self._file_count > self._threshold:
@@ -769,7 +774,7 @@ class FileSystemCache(BaseCache):
for idx, fname in enumerate(entries):
try:
remove = False
- with open(fname, 'rb') as f:
+ with open(fname, "rb") as f:
expires = pickle.load(f)
remove = (expires != 0 and expires <= now) or idx % 3 == 0
@@ -791,14 +796,14 @@ class FileSystemCache(BaseCache):
def _get_filename(self, key):
if isinstance(key, text_type):
- key = key.encode('utf-8') # XXX unicode review
+ key = key.encode("utf-8") # XXX unicode review
hash = md5(key).hexdigest()
return os.path.join(self._path, hash)
def get(self, key):
filename = self._get_filename(key)
try:
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
pickle_time = pickle.load(f)
if pickle_time == 0 or pickle_time >= time():
return pickle.load(f)
@@ -826,9 +831,10 @@ class FileSystemCache(BaseCache):
timeout = self._normalize_timeout(timeout)
filename = self._get_filename(key)
try:
- fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix,
- dir=self._path)
- with os.fdopen(fd, 'wb') as f:
+ fd, tmp = tempfile.mkstemp(
+ suffix=self._fs_transaction_suffix, dir=self._path
+ )
+ with os.fdopen(fd, "wb") as f:
pickle.dump(timeout, f, 1)
pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
rename(tmp, filename)
@@ -855,7 +861,7 @@ class FileSystemCache(BaseCache):
def has(self, key):
filename = self._get_filename(key)
try:
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
pickle_time = pickle.load(f)
if pickle_time == 0 or pickle_time >= time():
return True
@@ -867,7 +873,7 @@ class FileSystemCache(BaseCache):
class UWSGICache(BaseCache):
- """ Implements the cache using uWSGI's caching framework.
+ """Implements the cache using uWSGI's caching framework.
.. note::
This class cannot be used when running under PyPy, because the uWSGI
@@ -880,19 +886,24 @@ class UWSGICache(BaseCache):
same instance as the werkzeug app, you only have to provide the name of
the cache.
"""
- def __init__(self, default_timeout=300, cache=''):
+
+ def __init__(self, default_timeout=300, cache=""):
BaseCache.__init__(self, default_timeout)
- if platform.python_implementation() == 'PyPy':
- raise RuntimeError("uWSGI caching does not work under PyPy, see "
- "the docs for more details.")
+ if platform.python_implementation() == "PyPy":
+ raise RuntimeError(
+ "uWSGI caching does not work under PyPy, see "
+ "the docs for more details."
+ )
try:
import uwsgi
+
self._uwsgi = uwsgi
except ImportError:
- raise RuntimeError("uWSGI could not be imported, are you "
- "running under uWSGI?")
+ raise RuntimeError(
+ "uWSGI could not be imported, are you running under uWSGI?"
+ )
self.cache = cache
@@ -906,14 +917,14 @@ class UWSGICache(BaseCache):
return self._uwsgi.cache_del(key, self.cache)
def set(self, key, value, timeout=None):
- return self._uwsgi.cache_update(key, pickle.dumps(value),
- self._normalize_timeout(timeout),
- self.cache)
+ return self._uwsgi.cache_update(
+ key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache
+ )
def add(self, key, value, timeout=None):
- return self._uwsgi.cache_set(key, pickle.dumps(value),
- self._normalize_timeout(timeout),
- self.cache)
+ return self._uwsgi.cache_set(
+ key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache
+ )
def clear(self):
return self._uwsgi.cache_clear(self.cache)
diff --git a/src/werkzeug/contrib/fixers.py b/src/werkzeug/contrib/fixers.py
index ad386f57..8df0afda 100644
--- a/src/werkzeug/contrib/fixers.py
+++ b/src/werkzeug/contrib/fixers.py
@@ -28,16 +28,18 @@ This module includes various helpers that fix web server behavior.
"""
import warnings
+from ..datastructures import Headers
+from ..datastructures import ResponseCacheControl
+from ..http import parse_cache_control_header
+from ..http import parse_options_header
+from ..http import parse_set_header
+from ..middleware.proxy_fix import ProxyFix as _ProxyFix
+from ..useragents import UserAgent
+
try:
- from urllib import unquote
-except ImportError:
from urllib.parse import unquote
-
-from werkzeug.http import parse_options_header, parse_cache_control_header, \
- parse_set_header
-from werkzeug.useragents import UserAgent
-from werkzeug.datastructures import Headers, ResponseCacheControl
-from werkzeug.middleware.proxy_fix import ProxyFix as _ProxyFix
+except ImportError:
+ from urllib import unquote
class CGIRootFix(object):
@@ -57,7 +59,7 @@ class CGIRootFix(object):
``LighttpdCGIRootFix``.
"""
- def __init__(self, app, app_root='/'):
+ def __init__(self, app, app_root="/"):
warnings.warn(
"'CGIRootFix' is deprecated as of version 0.15 and will be"
" removed in version 1.0.",
@@ -68,7 +70,7 @@ class CGIRootFix(object):
self.app_root = app_root.strip("/")
def __call__(self, environ, start_response):
- environ['SCRIPT_NAME'] = self.app_root
+ environ["SCRIPT_NAME"] = self.app_root
return self.app(environ, start_response)
@@ -84,7 +86,6 @@ class LighttpdCGIRootFix(CGIRootFix):
class PathInfoFromRequestUriFix(object):
-
"""On windows environment variables are limited to the system charset
which makes it impossible to store the `PATH_INFO` variable in the
environment without loss of information on some systems.
@@ -112,14 +113,13 @@ class PathInfoFromRequestUriFix(object):
self.app = app
def __call__(self, environ, start_response):
- for key in 'REQUEST_URL', 'REQUEST_URI', 'UNENCODED_URL':
+ for key in "REQUEST_URL", "REQUEST_URI", "UNENCODED_URL":
if key not in environ:
continue
request_uri = unquote(environ[key])
- script_name = unquote(environ.get('SCRIPT_NAME', ''))
+ script_name = unquote(environ.get("SCRIPT_NAME", ""))
if request_uri.startswith(script_name):
- environ['PATH_INFO'] = request_uri[len(script_name):] \
- .split('?', 1)[0]
+ environ["PATH_INFO"] = request_uri[len(script_name) :].split("?", 1)[0]
break
return self.app(environ, start_response)
@@ -144,7 +144,6 @@ class ProxyFix(_ProxyFix):
class HeaderRewriterFix(object):
-
"""This middleware can remove response headers and add others. This
is for example useful to remove the `Date` header from responses if you
are using a server that adds that header, no matter if it's present or
@@ -182,11 +181,11 @@ class HeaderRewriterFix(object):
new_headers.append((key, value))
new_headers += self.add_headers
return start_response(status, new_headers, exc_info)
+
return self.app(environ, rewriting_start_response)
class InternetExplorerFix(object):
-
"""This middleware fixes a couple of bugs with Microsoft Internet
Explorer. Currently the following fixes are applied:
@@ -224,40 +223,40 @@ class InternetExplorerFix(object):
def fix_headers(self, environ, headers, status=None):
if self.fix_vary:
- header = headers.get('content-type', '')
+ header = headers.get("content-type", "")
mimetype, options = parse_options_header(header)
- if mimetype not in ('text/html', 'text/plain', 'text/sgml'):
- headers.pop('vary', None)
+ if mimetype not in ("text/html", "text/plain", "text/sgml"):
+ headers.pop("vary", None)
- if self.fix_attach and 'content-disposition' in headers:
- pragma = parse_set_header(headers.get('pragma', ''))
- pragma.discard('no-cache')
+ if self.fix_attach and "content-disposition" in headers:
+ pragma = parse_set_header(headers.get("pragma", ""))
+ pragma.discard("no-cache")
header = pragma.to_header()
if not header:
- headers.pop('pragma', '')
+ headers.pop("pragma", "")
else:
- headers['Pragma'] = header
- header = headers.get('cache-control', '')
+ headers["Pragma"] = header
+ header = headers.get("cache-control", "")
if header:
- cc = parse_cache_control_header(header,
- cls=ResponseCacheControl)
+ cc = parse_cache_control_header(header, cls=ResponseCacheControl)
cc.no_cache = None
cc.no_store = False
header = cc.to_header()
if not header:
- headers.pop('cache-control', '')
+ headers.pop("cache-control", "")
else:
- headers['Cache-Control'] = header
+ headers["Cache-Control"] = header
def run_fixed(self, environ, start_response):
def fixing_start_response(status, headers, exc_info=None):
headers = Headers(headers)
self.fix_headers(environ, headers, status)
return start_response(status, headers.to_wsgi_list(), exc_info)
+
return self.app(environ, fixing_start_response)
def __call__(self, environ, start_response):
ua = UserAgent(environ)
- if ua.browser != 'msie':
+ if ua.browser != "msie":
return self.app(environ, start_response)
return self.run_fixed(environ, start_response)
diff --git a/src/werkzeug/contrib/iterio.py b/src/werkzeug/contrib/iterio.py
index 41e1e20e..b6724540 100644
--- a/src/werkzeug/contrib/iterio.py
+++ b/src/werkzeug/contrib/iterio.py
@@ -41,13 +41,13 @@ r"""
"""
import warnings
+from .._compat import implements_iterator
+
try:
import greenlet
except ImportError:
greenlet = None
-from werkzeug._compat import implements_iterator
-
warnings.warn(
"'werkzeug.contrib.iterio' is deprecated as of version 0.15 and"
" will be removed in version 1.0.",
@@ -61,19 +61,18 @@ def _mixed_join(iterable, sentinel):
iterator = iter(iterable)
first_item = next(iterator, sentinel)
if isinstance(first_item, bytes):
- return first_item + b''.join(iterator)
- return first_item + u''.join(iterator)
+ return first_item + b"".join(iterator)
+ return first_item + u"".join(iterator)
def _newline(reference_string):
if isinstance(reference_string, bytes):
- return b'\n'
- return u'\n'
+ return b"\n"
+ return u"\n"
@implements_iterator
class IterIO(object):
-
"""Instances of this object implement an interface compatible with the
standard Python :class:`file` object. Streams are either read-only or
write-only depending on how the object is created.
@@ -100,7 +99,7 @@ class IterIO(object):
`sentinel` parameter was added.
"""
- def __new__(cls, obj, sentinel=''):
+ def __new__(cls, obj, sentinel=""):
try:
iterator = iter(obj)
except TypeError:
@@ -112,53 +111,53 @@ class IterIO(object):
def tell(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return self.pos
def isatty(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return False
def seek(self, pos, mode=0):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def truncate(self, size=None):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def write(self, s):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def writelines(self, list):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def read(self, n=-1):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def readlines(self, sizehint=0):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def readline(self, length=None):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def flush(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
- raise IOError(9, 'Bad file descriptor')
+ raise ValueError("I/O operation on closed file")
+ raise IOError(9, "Bad file descriptor")
def __next__(self):
if self.closed:
@@ -170,12 +169,11 @@ class IterIO(object):
class IterI(IterIO):
-
"""Convert an stream into an iterator."""
- def __new__(cls, func, sentinel=''):
+ def __new__(cls, func, sentinel=""):
if greenlet is None:
- raise RuntimeError('IterI requires greenlet support')
+ raise RuntimeError("IterI requires greenlet support")
stream = object.__new__(cls)
stream._parent = greenlet.getcurrent()
stream._buffer = []
@@ -201,7 +199,7 @@ class IterI(IterIO):
def write(self, s):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
if s:
self.pos += len(s)
self._buffer.append(s)
@@ -212,7 +210,7 @@ class IterI(IterIO):
def flush(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
self._flush_impl()
def _flush_impl(self):
@@ -225,10 +223,9 @@ class IterI(IterIO):
class IterO(IterIO):
-
"""Iter output. Wrap an iterator and give it a stream like interface."""
- def __new__(cls, gen, sentinel=''):
+ def __new__(cls, gen, sentinel=""):
self = object.__new__(cls)
self._gen = gen
self._buf = None
@@ -241,8 +238,8 @@ class IterO(IterIO):
return self
def _buf_append(self, string):
- '''Replace string directly without appending to an empty string,
- avoiding type issues.'''
+ """Replace string directly without appending to an empty string,
+ avoiding type issues."""
if not self._buf:
self._buf = string
else:
@@ -251,12 +248,12 @@ class IterO(IterIO):
def close(self):
if not self.closed:
self.closed = True
- if hasattr(self._gen, 'close'):
+ if hasattr(self._gen, "close"):
self._gen.close()
def seek(self, pos, mode=0):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
if mode == 1:
pos += self.pos
elif mode == 2:
@@ -264,10 +261,10 @@ class IterO(IterIO):
self.pos = min(self.pos, self.pos + pos)
return
elif mode != 0:
- raise IOError('Invalid argument')
+ raise IOError("Invalid argument")
buf = []
try:
- tmp_end_pos = len(self._buf or '')
+ tmp_end_pos = len(self._buf or "")
while pos > tmp_end_pos:
item = next(self._gen)
tmp_end_pos += len(item)
@@ -280,10 +277,10 @@ class IterO(IterIO):
def read(self, n=-1):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
if n < 0:
self._buf_append(_mixed_join(self._gen, self.sentinel))
- result = self._buf[self.pos:]
+ result = self._buf[self.pos :]
self.pos += len(result)
return result
new_pos = self.pos + n
@@ -304,13 +301,13 @@ class IterO(IterIO):
new_pos = max(0, new_pos)
try:
- return self._buf[self.pos:new_pos]
+ return self._buf[self.pos : new_pos]
finally:
self.pos = min(new_pos, len(self._buf))
def readline(self, length=None):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
nl_pos = -1
if self._buf:
@@ -344,7 +341,7 @@ class IterO(IterIO):
if length is not None and self.pos + length < new_pos:
new_pos = self.pos + length
try:
- return self._buf[self.pos:new_pos]
+ return self._buf[self.pos : new_pos]
finally:
self.pos = min(new_pos, len(self._buf))
diff --git a/src/werkzeug/contrib/securecookie.py b/src/werkzeug/contrib/securecookie.py
index b9ba20ed..c4c9eee2 100644
--- a/src/werkzeug/contrib/securecookie.py
+++ b/src/werkzeug/contrib/securecookie.py
@@ -88,19 +88,22 @@ r"""
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pickle
import base64
+import pickle
import warnings
+from hashlib import sha1 as _default_hash
from hmac import new as hmac
from time import time
-from hashlib import sha1 as _default_hash
-from werkzeug._compat import iteritems, text_type, to_bytes
-from werkzeug.urls import url_quote_plus, url_unquote_plus
-from werkzeug._internal import _date_to_unix
-from werkzeug.contrib.sessions import ModificationTrackingDict
-from werkzeug.security import safe_str_cmp
-from werkzeug._compat import to_native
+from .._compat import iteritems
+from .._compat import text_type
+from .._compat import to_bytes
+from .._compat import to_native
+from .._internal import _date_to_unix
+from ..contrib.sessions import ModificationTrackingDict
+from ..security import safe_str_cmp
+from ..urls import url_quote_plus
+from ..urls import url_unquote_plus
warnings.warn(
"'werkzeug.contrib.securecookie' is deprecated as of version 0.15"
@@ -112,12 +115,10 @@ warnings.warn(
class UnquoteError(Exception):
-
"""Internal exception used to signal failures on quoting."""
class SecureCookie(ModificationTrackingDict):
-
"""Represents a secure cookie. You can subclass this class and provide
an alternative mac method. The import thing is that the mac method
is a function with a similar interface to the hashlib. Required
@@ -164,7 +165,7 @@ class SecureCookie(ModificationTrackingDict):
# explicitly convert it into a bytestring because python 2.6
# no longer performs an implicit string conversion on hmac
if secret_key is not None:
- secret_key = to_bytes(secret_key, 'utf-8')
+ secret_key = to_bytes(secret_key, "utf-8")
self.secret_key = secret_key
self.new = new
@@ -178,10 +179,10 @@ class SecureCookie(ModificationTrackingDict):
)
def __repr__(self):
- return '<%s %s%s>' % (
+ return "<%s %s%s>" % (
self.__class__.__name__,
dict.__repr__(self),
- '*' if self.should_save else ''
+ "*" if self.should_save else "",
)
@property
@@ -201,7 +202,7 @@ class SecureCookie(ModificationTrackingDict):
if cls.serialization_method is not None:
value = cls.serialization_method.dumps(value)
if cls.quote_base64:
- value = b''.join(
+ value = b"".join(
base64.b64encode(to_bytes(value, "utf8")).splitlines()
).strip()
return value
@@ -236,21 +237,19 @@ class SecureCookie(ModificationTrackingDict):
:class:`datetime.datetime` object)
"""
if self.secret_key is None:
- raise RuntimeError('no secret key defined')
+ raise RuntimeError("no secret key defined")
if expires:
- self['_expires'] = _date_to_unix(expires)
+ self["_expires"] = _date_to_unix(expires)
result = []
mac = hmac(self.secret_key, None, self.hash_method)
for key, value in sorted(self.items()):
- result.append(('%s=%s' % (
- url_quote_plus(key),
- self.quote(value).decode('ascii')
- )).encode('ascii'))
- mac.update(b'|' + result[-1])
- return b'?'.join([
- base64.b64encode(mac.digest()).strip(),
- b'&'.join(result)
- ])
+ result.append(
+ (
+ "%s=%s" % (url_quote_plus(key), self.quote(value).decode("ascii"))
+ ).encode("ascii")
+ )
+ mac.update(b"|" + result[-1])
+ return b"?".join([base64.b64encode(mac.digest()).strip(), b"&".join(result)])
@classmethod
def unserialize(cls, string, secret_key):
@@ -261,24 +260,24 @@ class SecureCookie(ModificationTrackingDict):
:return: a new :class:`SecureCookie`.
"""
if isinstance(string, text_type):
- string = string.encode('utf-8', 'replace')
+ string = string.encode("utf-8", "replace")
if isinstance(secret_key, text_type):
- secret_key = secret_key.encode('utf-8', 'replace')
+ secret_key = secret_key.encode("utf-8", "replace")
try:
- base64_hash, data = string.split(b'?', 1)
+ base64_hash, data = string.split(b"?", 1)
except (ValueError, IndexError):
items = ()
else:
items = {}
mac = hmac(secret_key, None, cls.hash_method)
- for item in data.split(b'&'):
- mac.update(b'|' + item)
- if b'=' not in item:
+ for item in data.split(b"&"):
+ mac.update(b"|" + item)
+ if b"=" not in item:
items = None
break
- key, value = item.split(b'=', 1)
+ key, value = item.split(b"=", 1)
# try to make the key a string
- key = url_unquote_plus(key.decode('ascii'))
+ key = url_unquote_plus(key.decode("ascii"))
try:
key = to_native(key)
except UnicodeError:
@@ -298,17 +297,17 @@ class SecureCookie(ModificationTrackingDict):
except UnquoteError:
items = ()
else:
- if '_expires' in items:
- if time() > items['_expires']:
+ if "_expires" in items:
+ if time() > items["_expires"]:
items = ()
else:
- del items['_expires']
+ del items["_expires"]
else:
items = ()
return cls(items, secret_key, False)
@classmethod
- def load_cookie(cls, request, key='session', secret_key=None):
+ def load_cookie(cls, request, key="session", secret_key=None):
"""Loads a :class:`SecureCookie` from a cookie in request. If the
cookie is not set, a new :class:`SecureCookie` instanced is
returned.
@@ -325,9 +324,19 @@ class SecureCookie(ModificationTrackingDict):
return cls(secret_key=secret_key)
return cls.unserialize(data, secret_key)
- def save_cookie(self, response, key='session', expires=None,
- session_expires=None, max_age=None, path='/', domain=None,
- secure=None, httponly=False, force=False):
+ def save_cookie(
+ self,
+ response,
+ key="session",
+ expires=None,
+ session_expires=None,
+ max_age=None,
+ path="/",
+ domain=None,
+ secure=None,
+ httponly=False,
+ force=False,
+ ):
"""Saves the SecureCookie in a cookie on response object. All
parameters that are not described here are forwarded directly
to :meth:`~BaseResponse.set_cookie`.
@@ -341,6 +350,13 @@ class SecureCookie(ModificationTrackingDict):
"""
if force or self.should_save:
data = self.serialize(session_expires or expires)
- response.set_cookie(key, data, expires=expires, max_age=max_age,
- path=path, domain=domain, secure=secure,
- httponly=httponly)
+ response.set_cookie(
+ key,
+ data,
+ expires=expires,
+ max_age=max_age,
+ path=path,
+ domain=domain,
+ secure=secure,
+ httponly=httponly,
+ )
diff --git a/src/werkzeug/contrib/sessions.py b/src/werkzeug/contrib/sessions.py
index 302fd11c..866e827c 100644
--- a/src/werkzeug/contrib/sessions.py
+++ b/src/werkzeug/contrib/sessions.py
@@ -51,22 +51,26 @@ r"""
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import re
import os
+import re
import tempfile
import warnings
+from hashlib import sha1
from os import path
-from time import time
+from pickle import dump
+from pickle import HIGHEST_PROTOCOL
+from pickle import load
from random import random
-from hashlib import sha1
-from pickle import dump, load, HIGHEST_PROTOCOL
+from time import time
-from werkzeug.datastructures import CallbackDict
-from werkzeug.utils import dump_cookie, parse_cookie
-from werkzeug.wsgi import ClosingIterator
-from werkzeug.posixemulation import rename
-from werkzeug._compat import PY2, text_type
-from werkzeug.filesystem import get_filesystem_encoding
+from .._compat import PY2
+from .._compat import text_type
+from ..datastructures import CallbackDict
+from ..filesystem import get_filesystem_encoding
+from ..posixemulation import rename
+from ..utils import dump_cookie
+from ..utils import parse_cookie
+from ..wsgi import ClosingIterator
warnings.warn(
"'werkzeug.contrib.sessions' is deprecated as of version 0.15 and"
@@ -76,31 +80,28 @@ warnings.warn(
stacklevel=2,
)
-_sha1_re = re.compile(r'^[a-f0-9]{40}$')
+_sha1_re = re.compile(r"^[a-f0-9]{40}$")
def _urandom():
- if hasattr(os, 'urandom'):
+ if hasattr(os, "urandom"):
return os.urandom(30)
- return text_type(random()).encode('ascii')
+ return text_type(random()).encode("ascii")
def generate_key(salt=None):
if salt is None:
- salt = repr(salt).encode('ascii')
- return sha1(b''.join([
- salt,
- str(time()).encode('ascii'),
- _urandom()
- ])).hexdigest()
+ salt = repr(salt).encode("ascii")
+ return sha1(b"".join([salt, str(time()).encode("ascii"), _urandom()])).hexdigest()
class ModificationTrackingDict(CallbackDict):
- __slots__ = ('modified',)
+ __slots__ = ("modified",)
def __init__(self, *args, **kwargs):
def on_update(self):
self.modified = True
+
self.modified = False
CallbackDict.__init__(self, on_update=on_update)
dict.update(self, *args, **kwargs)
@@ -120,12 +121,12 @@ class ModificationTrackingDict(CallbackDict):
class Session(ModificationTrackingDict):
-
"""Subclass of a dict that keeps track of direct object changes. Changes
in mutable structures are not tracked, for those you have to set
`modified` to `True` by hand.
"""
- __slots__ = ModificationTrackingDict.__slots__ + ('sid', 'new')
+
+ __slots__ = ModificationTrackingDict.__slots__ + ("sid", "new")
def __init__(self, data, sid, new=False):
ModificationTrackingDict.__init__(self, data)
@@ -133,10 +134,10 @@ class Session(ModificationTrackingDict):
self.new = new
def __repr__(self):
- return '<%s %s%s>' % (
+ return "<%s %s%s>" % (
self.__class__.__name__,
dict.__repr__(self),
- '*' if self.should_save else ''
+ "*" if self.should_save else "",
)
@property
@@ -151,7 +152,6 @@ class Session(ModificationTrackingDict):
class SessionStore(object):
-
"""Baseclass for all session stores. The Werkzeug contrib module does not
implement any useful stores besides the filesystem store, application
developers are encouraged to create their own stores.
@@ -197,11 +197,10 @@ class SessionStore(object):
#: used for temporary files by the filesystem session store
-_fs_transaction_suffix = '.__wz_sess'
+_fs_transaction_suffix = ".__wz_sess"
class FilesystemSessionStore(SessionStore):
-
"""Simple example session store that saves sessions on the filesystem.
This store works best on POSIX systems and Windows Vista / Windows
Server 2008 and newer.
@@ -223,17 +222,23 @@ class FilesystemSessionStore(SessionStore):
not yet saved.
"""
- def __init__(self, path=None, filename_template='werkzeug_%s.sess',
- session_class=None, renew_missing=False, mode=0o644):
+ def __init__(
+ self,
+ path=None,
+ filename_template="werkzeug_%s.sess",
+ session_class=None,
+ renew_missing=False,
+ mode=0o644,
+ ):
SessionStore.__init__(self, session_class)
if path is None:
path = tempfile.gettempdir()
self.path = path
if isinstance(filename_template, text_type) and PY2:
- filename_template = filename_template.encode(
- get_filesystem_encoding())
- assert not filename_template.endswith(_fs_transaction_suffix), \
- 'filename templates may not end with %s' % _fs_transaction_suffix
+ filename_template = filename_template.encode(get_filesystem_encoding())
+ assert not filename_template.endswith(_fs_transaction_suffix), (
+ "filename templates may not end with %s" % _fs_transaction_suffix
+ )
self.filename_template = filename_template
self.renew_missing = renew_missing
self.mode = mode
@@ -248,9 +253,8 @@ class FilesystemSessionStore(SessionStore):
def save(self, session):
fn = self.get_session_filename(session.sid)
- fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix,
- dir=self.path)
- f = os.fdopen(fd, 'wb')
+ fd, tmp = tempfile.mkstemp(suffix=_fs_transaction_suffix, dir=self.path)
+ f = os.fdopen(fd, "wb")
try:
dump(dict(session), f, HIGHEST_PROTOCOL)
finally:
@@ -272,7 +276,7 @@ class FilesystemSessionStore(SessionStore):
if not self.is_valid_key(sid):
return self.new()
try:
- f = open(self.get_session_filename(sid), 'rb')
+ f = open(self.get_session_filename(sid), "rb")
except IOError:
if self.renew_missing:
return self.new()
@@ -292,9 +296,10 @@ class FilesystemSessionStore(SessionStore):
.. versionadded:: 0.6
"""
- before, after = self.filename_template.split('%s', 1)
- filename_re = re.compile(r'%s(.{5,})%s$' % (re.escape(before),
- re.escape(after)))
+ before, after = self.filename_template.split("%s", 1)
+ filename_re = re.compile(
+ r"%s(.{5,})%s$" % (re.escape(before), re.escape(after))
+ )
result = []
for filename in os.listdir(self.path):
#: this is a session that is still being saved.
@@ -307,7 +312,6 @@ class FilesystemSessionStore(SessionStore):
class SessionMiddleware(object):
-
"""A simple middleware that puts the session object of a store provided
into the WSGI environ. It automatically sets cookies and restores
sessions.
@@ -323,11 +327,20 @@ class SessionMiddleware(object):
compatibility.
"""
- def __init__(self, app, store, cookie_name='session_id',
- cookie_age=None, cookie_expires=None, cookie_path='/',
- cookie_domain=None, cookie_secure=None,
- cookie_httponly=False, cookie_samesite='Lax',
- environ_key='werkzeug.session'):
+ def __init__(
+ self,
+ app,
+ store,
+ cookie_name="session_id",
+ cookie_age=None,
+ cookie_expires=None,
+ cookie_path="/",
+ cookie_domain=None,
+ cookie_secure=None,
+ cookie_httponly=False,
+ cookie_samesite="Lax",
+ environ_key="werkzeug.session",
+ ):
self.app = app
self.store = store
self.cookie_name = cookie_name
@@ -341,7 +354,7 @@ class SessionMiddleware(object):
self.environ_key = environ_key
def __call__(self, environ, start_response):
- cookie = parse_cookie(environ.get('HTTP_COOKIE', ''))
+ cookie = parse_cookie(environ.get("HTTP_COOKIE", ""))
sid = cookie.get(self.cookie_name, None)
if sid is None:
session = self.store.new()
@@ -352,12 +365,25 @@ class SessionMiddleware(object):
def injecting_start_response(status, headers, exc_info=None):
if session.should_save:
self.store.save(session)
- headers.append(('Set-Cookie', dump_cookie(self.cookie_name,
- session.sid, self.cookie_age,
- self.cookie_expires, self.cookie_path,
- self.cookie_domain, self.cookie_secure,
- self.cookie_httponly,
- samesite=self.cookie_samesite)))
+ headers.append(
+ (
+ "Set-Cookie",
+ dump_cookie(
+ self.cookie_name,
+ session.sid,
+ self.cookie_age,
+ self.cookie_expires,
+ self.cookie_path,
+ self.cookie_domain,
+ self.cookie_secure,
+ self.cookie_httponly,
+ samesite=self.cookie_samesite,
+ ),
+ )
+ )
return start_response(status, headers, exc_info)
- return ClosingIterator(self.app(environ, injecting_start_response),
- lambda: self.store.save_if_modified(session))
+
+ return ClosingIterator(
+ self.app(environ, injecting_start_response),
+ lambda: self.store.save_if_modified(session),
+ )
diff --git a/src/werkzeug/contrib/wrappers.py b/src/werkzeug/contrib/wrappers.py
index fa3bc565..49b82a71 100644
--- a/src/werkzeug/contrib/wrappers.py
+++ b/src/werkzeug/contrib/wrappers.py
@@ -23,11 +23,12 @@
import codecs
import warnings
-from werkzeug.exceptions import BadRequest
-from werkzeug.utils import cached_property
-from werkzeug.http import dump_options_header, parse_options_header
-from werkzeug._compat import wsgi_decoding_dance
-from werkzeug.wrappers.json import JSONMixin as _JSONMixin
+from .._compat import wsgi_decoding_dance
+from ..exceptions import BadRequest
+from ..http import dump_options_header
+from ..http import parse_options_header
+from ..utils import cached_property
+from ..wrappers.json import JSONMixin as _JSONMixin
def is_known_charset(charset):
@@ -87,8 +88,8 @@ class ProtobufRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- if 'protobuf' not in self.environ.get('CONTENT_TYPE', ''):
- raise BadRequest('Not a Protobuf request')
+ if "protobuf" not in self.environ.get("CONTENT_TYPE", ""):
+ raise BadRequest("Not a Protobuf request")
obj = proto_type()
try:
@@ -108,7 +109,8 @@ class RoutingArgsRequestMixin(object):
"""This request mixin adds support for the wsgiorg routing args
`specification`_.
- .. _specification: https://wsgi.readthedocs.io/en/latest/specifications/routing_args.html
+ .. _specification: https://wsgi.readthedocs.io/en/latest/
+ specifications/routing_args.html
.. deprecated:: 0.15
This mixin will be removed in version 1.0.
@@ -122,7 +124,7 @@ class RoutingArgsRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- return self.environ.get('wsgiorg.routing_args', (()))[0]
+ return self.environ.get("wsgiorg.routing_args", (()))[0]
def _set_routing_args(self, value):
warnings.warn(
@@ -133,13 +135,19 @@ class RoutingArgsRequestMixin(object):
stacklevel=2,
)
if self.shallow:
- raise RuntimeError('A shallow request tried to modify the WSGI '
- 'environment. If you really want to do that, '
- 'set `shallow` to False.')
- self.environ['wsgiorg.routing_args'] = (value, self.routing_vars)
-
- routing_args = property(_get_routing_args, _set_routing_args, doc='''
- The positional URL arguments as `tuple`.''')
+ raise RuntimeError(
+ "A shallow request tried to modify the WSGI "
+ "environment. If you really want to do that, "
+ "set `shallow` to False."
+ )
+ self.environ["wsgiorg.routing_args"] = (value, self.routing_vars)
+
+ routing_args = property(
+ _get_routing_args,
+ _set_routing_args,
+ doc="""
+ The positional URL arguments as `tuple`.""",
+ )
del _get_routing_args, _set_routing_args
def _get_routing_vars(self):
@@ -150,7 +158,7 @@ class RoutingArgsRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- rv = self.environ.get('wsgiorg.routing_args')
+ rv = self.environ.get("wsgiorg.routing_args")
if rv is not None:
return rv[1]
rv = {}
@@ -167,13 +175,19 @@ class RoutingArgsRequestMixin(object):
stacklevel=2,
)
if self.shallow:
- raise RuntimeError('A shallow request tried to modify the WSGI '
- 'environment. If you really want to do that, '
- 'set `shallow` to False.')
- self.environ['wsgiorg.routing_args'] = (self.routing_args, value)
-
- routing_vars = property(_get_routing_vars, _set_routing_vars, doc='''
- The keyword URL arguments as `dict`.''')
+ raise RuntimeError(
+ "A shallow request tried to modify the WSGI "
+ "environment. If you really want to do that, "
+ "set `shallow` to False."
+ )
+ self.environ["wsgiorg.routing_args"] = (self.routing_args, value)
+
+ routing_vars = property(
+ _get_routing_vars,
+ _set_routing_vars,
+ doc="""
+ The keyword URL arguments as `dict`.""",
+ )
del _get_routing_vars, _set_routing_vars
@@ -216,9 +230,10 @@ class ReverseSlashBehaviorRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '',
- self.charset, self.encoding_errors)
- return path.lstrip('/')
+ path = wsgi_decoding_dance(
+ self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors
+ )
+ return path.lstrip("/")
@cached_property
def script_root(self):
@@ -230,9 +245,10 @@ class ReverseSlashBehaviorRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '',
- self.charset, self.encoding_errors)
- return path.rstrip('/') + '/'
+ path = wsgi_decoding_dance(
+ self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors
+ )
+ return path.rstrip("/") + "/"
class DynamicCharsetRequestMixin(object):
@@ -268,7 +284,7 @@ class DynamicCharsetRequestMixin(object):
#: is latin1 which is what HTTP specifies as default charset.
#: You may however want to set this to utf-8 to better support
#: browsers that do not transmit a charset for incoming data.
- default_charset = 'latin1'
+ default_charset = "latin1"
def unknown_charset(self, charset):
"""Called if a charset was provided but is not supported by
@@ -279,7 +295,7 @@ class DynamicCharsetRequestMixin(object):
:param charset: the charset that was not found.
:return: the replacement charset.
"""
- return 'latin1'
+ return "latin1"
@cached_property
def charset(self):
@@ -291,10 +307,10 @@ class DynamicCharsetRequestMixin(object):
DeprecationWarning,
stacklevel=2,
)
- header = self.environ.get('CONTENT_TYPE')
+ header = self.environ.get("CONTENT_TYPE")
if header:
ct, options = parse_options_header(header)
- charset = options.get('charset')
+ charset = options.get("charset")
if charset:
if is_known_charset(charset):
return charset
@@ -327,7 +343,7 @@ class DynamicCharsetResponseMixin(object):
"""
#: the default charset.
- default_charset = 'utf-8'
+ default_charset = "utf-8"
def _get_charset(self):
warnings.warn(
@@ -337,9 +353,9 @@ class DynamicCharsetResponseMixin(object):
DeprecationWarning,
stacklevel=2,
)
- header = self.headers.get('content-type')
+ header = self.headers.get("content-type")
if header:
- charset = parse_options_header(header)[1].get('charset')
+ charset = parse_options_header(header)[1].get("charset")
if charset:
return charset
return self.default_charset
@@ -352,15 +368,18 @@ class DynamicCharsetResponseMixin(object):
DeprecationWarning,
stacklevel=2,
)
- header = self.headers.get('content-type')
+ header = self.headers.get("content-type")
ct, options = parse_options_header(header)
if not ct:
- raise TypeError('Cannot set charset if Content-Type '
- 'header is missing.')
- options['charset'] = charset
- self.headers['Content-Type'] = dump_options_header(ct, options)
-
- charset = property(_get_charset, _set_charset, doc="""
+ raise TypeError("Cannot set charset if Content-Type header is missing.")
+ options["charset"] = charset
+ self.headers["Content-Type"] = dump_options_header(ct, options)
+
+ charset = property(
+ _get_charset,
+ _set_charset,
+ doc="""
The charset for the response. It's stored inside the
- Content-Type header as a parameter.""")
+ Content-Type header as a parameter.""",
+ )
del _get_charset, _set_charset
diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py
index a39397a2..9643db96 100644
--- a/src/werkzeug/datastructures.py
+++ b/src/werkzeug/datastructures.py
@@ -8,24 +8,32 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import re
import codecs
import mimetypes
+import re
from copy import deepcopy
from itertools import repeat
-from werkzeug._internal import _missing
-from werkzeug._compat import BytesIO, collections_abc, iterkeys, itervalues, \
- iteritems, iterlists, PY2, text_type, integer_types, string_types, \
- make_literal_wrapper, to_native
-from werkzeug.filesystem import get_filesystem_encoding
-
+from ._compat import BytesIO
+from ._compat import collections_abc
+from ._compat import integer_types
+from ._compat import iteritems
+from ._compat import iterkeys
+from ._compat import iterlists
+from ._compat import itervalues
+from ._compat import make_literal_wrapper
+from ._compat import PY2
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_native
+from ._internal import _missing
+from .filesystem import get_filesystem_encoding
-_locale_delim_re = re.compile(r'[_-]')
+_locale_delim_re = re.compile(r"[_-]")
def is_immutable(self):
- raise TypeError('%r objects are immutable' % self.__class__.__name__)
+ raise TypeError("%r objects are immutable" % self.__class__.__name__)
def iter_multi_items(mapping):
@@ -52,18 +60,28 @@ def native_itermethods(names):
return lambda x: x
def setviewmethod(cls, name):
- viewmethod_name = 'view%s' % name
- viewmethod = lambda self, *a, **kw: ViewItems(self, name, 'view_%s' % name, *a, **kw)
- viewmethod.__doc__ = \
- '"""`%s()` object providing a view on %s"""' % (viewmethod_name, name)
+ viewmethod_name = "view%s" % name
+ repr_name = "view_%s" % name
+
+ def viewmethod(self, *a, **kw):
+ return ViewItems(self, name, repr_name, *a, **kw)
+
+ viewmethod.__name__ = viewmethod_name
+ viewmethod.__doc__ = "`%s()` object providing a view on %s" % (
+ viewmethod_name,
+ name,
+ )
setattr(cls, viewmethod_name, viewmethod)
def setitermethod(cls, name):
itermethod = getattr(cls, name)
- setattr(cls, 'iter%s' % name, itermethod)
- listmethod = lambda self, *a, **kw: list(itermethod(self, *a, **kw))
- listmethod.__doc__ = \
- 'Like :py:meth:`iter%s`, but returns a list.' % name
+ setattr(cls, "iter%s" % name, itermethod)
+
+ def listmethod(self, *a, **kw):
+ return list(itermethod(self, *a, **kw))
+
+ listmethod.__name__ = name
+ listmethod.__doc__ = "Like :py:meth:`iter%s`, but returns a list." % name
setattr(cls, name, listmethod)
def wrap(cls):
@@ -71,11 +89,11 @@ def native_itermethods(names):
setitermethod(cls, name)
setviewmethod(cls, name)
return cls
+
return wrap
class ImmutableListMixin(object):
-
"""Makes a :class:`list` immutable.
.. versionadded:: 0.5
@@ -99,6 +117,7 @@ class ImmutableListMixin(object):
def __iadd__(self, other):
is_immutable(self)
+
__imul__ = __iadd__
def __setitem__(self, key, value):
@@ -106,6 +125,7 @@ class ImmutableListMixin(object):
def append(self, item):
is_immutable(self)
+
remove = append
def extend(self, iterable):
@@ -125,7 +145,6 @@ class ImmutableListMixin(object):
class ImmutableList(ImmutableListMixin, list):
-
"""An immutable :class:`list`.
.. versionadded:: 0.5
@@ -134,20 +153,17 @@ class ImmutableList(ImmutableListMixin, list):
"""
def __repr__(self):
- return '%s(%s)' % (
- self.__class__.__name__,
- list.__repr__(self),
- )
+ return "%s(%s)" % (self.__class__.__name__, list.__repr__(self))
class ImmutableDictMixin(object):
-
"""Makes a :class:`dict` immutable.
.. versionadded:: 0.5
:private:
"""
+
_hash_cache = None
@classmethod
@@ -191,7 +207,6 @@ class ImmutableDictMixin(object):
class ImmutableMultiDictMixin(ImmutableDictMixin):
-
"""Makes a :class:`MultiDict` immutable.
.. versionadded:: 0.5
@@ -222,7 +237,6 @@ class ImmutableMultiDictMixin(ImmutableDictMixin):
class UpdateDictMixin(object):
-
"""Makes dicts call `self.on_update` on modifications.
.. versionadded:: 0.5
@@ -232,12 +246,13 @@ class UpdateDictMixin(object):
on_update = None
- def calls_update(name):
+ def calls_update(name): # noqa: B902
def oncall(self, *args, **kw):
rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw)
if self.on_update is not None:
self.on_update(self)
return rv
+
oncall.__name__ = name
return oncall
@@ -258,16 +273,15 @@ class UpdateDictMixin(object):
self.on_update(self)
return rv
- __setitem__ = calls_update('__setitem__')
- __delitem__ = calls_update('__delitem__')
- clear = calls_update('clear')
- popitem = calls_update('popitem')
- update = calls_update('update')
+ __setitem__ = calls_update("__setitem__")
+ __delitem__ = calls_update("__delitem__")
+ clear = calls_update("clear")
+ popitem = calls_update("popitem")
+ update = calls_update("update")
del calls_update
class TypeConversionDict(dict):
-
"""Works like a regular dict but the :meth:`get` method can perform
type conversions. :class:`MultiDict` and :class:`CombinedMultiDict`
are subclasses of this class and provide the same feature.
@@ -309,7 +323,6 @@ class TypeConversionDict(dict):
class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict):
-
"""Works like a :class:`TypeConversionDict` but does not support
modifications.
@@ -328,7 +341,6 @@ class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict):
class ViewItems(object):
-
def __init__(self, multi_dict, method, repr_name, *a, **kw):
self.__multi_dict = multi_dict
self.__method = method
@@ -340,15 +352,14 @@ class ViewItems(object):
return getattr(self.__multi_dict, self.__method)(*self.__a, **self.__kw)
def __repr__(self):
- return '%s(%r)' % (self.__repr_name, list(self.__get_items()))
+ return "%s(%r)" % (self.__repr_name, list(self.__get_items()))
def __iter__(self):
return iter(self.__get_items())
-@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues'])
+@native_itermethods(["keys", "values", "items", "lists", "listvalues"])
class MultiDict(TypeConversionDict):
-
"""A :class:`MultiDict` is a dictionary subclass customized to deal with
multiple values for the same key which is for example used by the parsing
functions in the wrappers. This is necessary because some HTML form
@@ -678,17 +689,17 @@ class MultiDict(TypeConversionDict):
return self.deepcopy(memo=memo)
def __repr__(self):
- return '%s(%r)' % (self.__class__.__name__, list(iteritems(self, multi=True)))
+ return "%s(%r)" % (self.__class__.__name__, list(iteritems(self, multi=True)))
class _omd_bucket(object):
-
"""Wraps values in the :class:`OrderedMultiDict`. This makes it
possible to keep an order over multiple different keys. It requires
a lot of extra memory and slows down access a lot, but makes it
possible to access elements in O(1) and iterate in O(n).
"""
- __slots__ = ('prev', 'key', 'value', 'next')
+
+ __slots__ = ("prev", "key", "value", "next")
def __init__(self, omd, key, value):
self.prev = omd._last_bucket
@@ -713,9 +724,8 @@ class _omd_bucket(object):
omd._last_bucket = self.prev
-@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues'])
+@native_itermethods(["keys", "values", "items", "lists", "listvalues"])
class OrderedMultiDict(MultiDict):
-
"""Works like a regular :class:`MultiDict` but preserves the
order of the fields. To convert the ordered multi dict into a
list you can use the :meth:`items` method and pass it ``multi=True``.
@@ -822,7 +832,7 @@ class OrderedMultiDict(MultiDict):
ptr = ptr.next
def listvalues(self):
- for key, values in iterlists(self):
+ for _key, values in iterlists(self):
yield values
def add(self, key, value):
@@ -849,8 +859,7 @@ class OrderedMultiDict(MultiDict):
self.add(key, value)
def setlistdefault(self, key, default_list=None):
- raise TypeError('setlistdefault is unsupported for '
- 'ordered multi dicts')
+ raise TypeError("setlistdefault is unsupported for ordered multi dicts")
def update(self, mapping):
for key, value in iter_multi_items(mapping):
@@ -893,21 +902,21 @@ class OrderedMultiDict(MultiDict):
def _options_header_vkw(value, kw):
- return dump_options_header(value, dict((k.replace('_', '-'), v)
- for k, v in kw.items()))
+ return dump_options_header(
+ value, dict((k.replace("_", "-"), v) for k, v in kw.items())
+ )
def _unicodify_header_value(value):
if isinstance(value, bytes):
- value = value.decode('latin-1')
+ value = value.decode("latin-1")
if not isinstance(value, text_type):
value = text_type(value)
return value
-@native_itermethods(['keys', 'values', 'items'])
+@native_itermethods(["keys", "values", "items"])
class Headers(object):
-
"""An object that stores some headers. It has a dict-like interface
but is ordered and can store the same keys multiple times.
@@ -968,8 +977,7 @@ class Headers(object):
raise exceptions.BadRequestKeyError(key)
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- set(other._list) == set(self._list)
+ return other.__class__ is self.__class__ and set(other._list) == set(self._list)
__hash__ = None
@@ -1007,7 +1015,7 @@ class Headers(object):
except KeyError:
return default
if as_bytes:
- rv = rv.encode('latin1')
+ rv = rv.encode("latin1")
if type is None:
return rv
try:
@@ -1036,7 +1044,7 @@ class Headers(object):
for k, v in self:
if k.lower() == ikey:
if as_bytes:
- v = v.encode('latin1')
+ v = v.encode("latin1")
if type is not None:
try:
v = type(v)
@@ -1168,10 +1176,12 @@ class Headers(object):
def _validate_value(self, value):
if not isinstance(value, text_type):
- raise TypeError('Value should be unicode.')
- if u'\n' in value or u'\r' in value:
- raise ValueError('Detected newline in header value. This is '
- 'a potential security problem')
+ raise TypeError("Value should be unicode.")
+ if u"\n" in value or u"\r" in value:
+ raise ValueError(
+ "Detected newline in header value. This is "
+ "a potential security problem"
+ )
def add_header(self, _key, _value, **_kw):
"""Add a new header tuple to the list.
@@ -1210,7 +1220,7 @@ class Headers(object):
return
listiter = iter(self._list)
ikey = _key.lower()
- for idx, (old_key, old_value) in enumerate(listiter):
+ for idx, (old_key, _old_value) in enumerate(listiter):
if old_key.lower() == ikey:
# replace first ocurrence
self._list[idx] = (_key, _value)
@@ -1218,7 +1228,7 @@ class Headers(object):
else:
self._list.append((_key, _value))
return
- self._list[idx + 1:] = [t for t in listiter if t[0].lower() != ikey]
+ self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey]
def setdefault(self, key, default):
"""Returns the value for the key if it is in the dict, otherwise it
@@ -1250,17 +1260,18 @@ class Headers(object):
else:
self.set(key, value)
- def to_list(self, charset='iso-8859-1'):
+ def to_list(self, charset="iso-8859-1"):
"""Convert the headers into a list suitable for WSGI.
.. deprecated:: 0.9
"""
from warnings import warn
+
warn(
"'to_list' deprecated as of version 0.9 and will be removed"
" in version 1.0. Use 'to_wsgi_list' instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
return self.to_wsgi_list()
@@ -1273,7 +1284,7 @@ class Headers(object):
:return: list
"""
if PY2:
- return [(to_native(k), v.encode('latin1')) for k, v in self]
+ return [(to_native(k), v.encode("latin1")) for k, v in self]
return list(self)
def copy(self):
@@ -1286,19 +1297,15 @@ class Headers(object):
"""Returns formatted headers suitable for HTTP transmission."""
strs = []
for key, value in self.to_wsgi_list():
- strs.append('%s: %s' % (key, value))
- strs.append('\r\n')
- return '\r\n'.join(strs)
+ strs.append("%s: %s" % (key, value))
+ strs.append("\r\n")
+ return "\r\n".join(strs)
def __repr__(self):
- return '%s(%r)' % (
- self.__class__.__name__,
- list(self)
- )
+ return "%s(%r)" % (self.__class__.__name__, list(self))
class ImmutableHeadersMixin(object):
-
"""Makes a :class:`Headers` immutable. We do not mark them as
hashable though since the only usecase for this datastructure
in Werkzeug is a view on a mutable structure.
@@ -1313,10 +1320,12 @@ class ImmutableHeadersMixin(object):
def __setitem__(self, key, value):
is_immutable(self)
+
set = __setitem__
def add(self, item):
is_immutable(self)
+
remove = add_header = add
def extend(self, iterable):
@@ -1336,7 +1345,6 @@ class ImmutableHeadersMixin(object):
class EnvironHeaders(ImmutableHeadersMixin, Headers):
-
"""Read only version of the headers from a WSGI environment. This
provides the same interface as `Headers` and is constructed from
a WSGI environment.
@@ -1360,10 +1368,10 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers):
# used because get() calls it.
if not isinstance(key, string_types):
raise KeyError(key)
- key = key.upper().replace('-', '_')
- if key in ('CONTENT_TYPE', 'CONTENT_LENGTH'):
+ key = key.upper().replace("-", "_")
+ if key in ("CONTENT_TYPE", "CONTENT_LENGTH"):
return _unicodify_header_value(self.environ[key])
- return _unicodify_header_value(self.environ['HTTP_' + key])
+ return _unicodify_header_value(self.environ["HTTP_" + key])
def __len__(self):
# the iter is necessary because otherwise list calls our
@@ -1372,21 +1380,23 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers):
def __iter__(self):
for key, value in iteritems(self.environ):
- if key.startswith('HTTP_') and key not in \
- ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
- yield (key[5:].replace('_', '-').title(),
- _unicodify_header_value(value))
- elif key in ('CONTENT_TYPE', 'CONTENT_LENGTH') and value:
- yield (key.replace('_', '-').title(),
- _unicodify_header_value(value))
+ if key.startswith("HTTP_") and key not in (
+ "HTTP_CONTENT_TYPE",
+ "HTTP_CONTENT_LENGTH",
+ ):
+ yield (
+ key[5:].replace("_", "-").title(),
+ _unicodify_header_value(value),
+ )
+ elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value:
+ yield (key.replace("_", "-").title(), _unicodify_header_value(value))
def copy(self):
- raise TypeError('cannot create %r copies' % self.__class__.__name__)
+ raise TypeError("cannot create %r copies" % self.__class__.__name__)
-@native_itermethods(['keys', 'values', 'items', 'lists', 'listvalues'])
+@native_itermethods(["keys", "values", "items", "lists", "listvalues"])
class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
-
"""A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict`
instances as sequence and it will combine the return values of all wrapped
dicts:
@@ -1417,8 +1427,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
@classmethod
def fromkeys(cls):
- raise TypeError('cannot create %r instances by fromkeys' %
- cls.__name__)
+ raise TypeError("cannot create %r instances by fromkeys" % cls.__name__)
def __getitem__(self, key):
for d in self.dicts:
@@ -1471,7 +1480,7 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
yield key, value
def values(self):
- for key, value in iteritems(self):
+ for _key, value in iteritems(self):
yield value
def lists(self):
@@ -1523,11 +1532,10 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
has_key = __contains__
def __repr__(self):
- return '%s(%r)' % (self.__class__.__name__, self.dicts)
+ return "%s(%r)" % (self.__class__.__name__, self.dicts)
class FileMultiDict(MultiDict):
-
"""A special :class:`MultiDict` that has convenience methods to add
files to it. This is used for :class:`EnvironBuilder` and generally
useful for unittesting.
@@ -1550,27 +1558,24 @@ class FileMultiDict(MultiDict):
if isinstance(file, string_types):
if filename is None:
filename = file
- file = open(file, 'rb')
+ file = open(file, "rb")
if filename and content_type is None:
- content_type = mimetypes.guess_type(filename)[0] or \
- 'application/octet-stream'
+ content_type = (
+ mimetypes.guess_type(filename)[0] or "application/octet-stream"
+ )
value = FileStorage(file, filename, name, content_type)
self.add(name, value)
class ImmutableDict(ImmutableDictMixin, dict):
-
"""An immutable :class:`dict`.
.. versionadded:: 0.5
"""
def __repr__(self):
- return '%s(%s)' % (
- self.__class__.__name__,
- dict.__repr__(self),
- )
+ return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self))
def copy(self):
"""Return a shallow mutable copy of this object. Keep in mind that
@@ -1584,7 +1589,6 @@ class ImmutableDict(ImmutableDictMixin, dict):
class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict):
-
"""An immutable :class:`MultiDict`.
.. versionadded:: 0.5
@@ -1602,7 +1606,6 @@ class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict):
class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict):
-
"""An immutable :class:`OrderedMultiDict`.
.. versionadded:: 0.6
@@ -1622,9 +1625,8 @@ class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict):
return self
-@native_itermethods(['values'])
+@native_itermethods(["values"])
class Accept(ImmutableList):
-
"""An :class:`Accept` object is just a list subclass for lists of
``(value, quality)`` tuples. It is automatically sorted by specificity
and quality.
@@ -1663,17 +1665,20 @@ class Accept(ImmutableList):
list.__init__(self, values)
else:
self.provided = True
- values = sorted(values, key=lambda x: (self._specificity(x[0]), x[1], x[0]),
- reverse=True)
+ values = sorted(
+ values,
+ key=lambda x: (self._specificity(x[0]), x[1], x[0]),
+ reverse=True,
+ )
list.__init__(self, values)
def _specificity(self, value):
"""Returns a tuple describing the value's specificity."""
- return value != '*',
+ return (value != "*",)
def _value_matches(self, value, item):
"""Check if a value matches a given accept item."""
- return item == '*' or item.lower() == value.lower()
+ return item == "*" or item.lower() == value.lower()
def __getitem__(self, key):
"""Besides index lookup (getting item n) you can also pass it a string
@@ -1697,15 +1702,15 @@ class Accept(ImmutableList):
return 0
def __contains__(self, value):
- for item, quality in self:
+ for item, _quality in self:
if self._value_matches(value, item):
return True
return False
def __repr__(self):
- return '%s([%s])' % (
+ return "%s([%s])" % (
self.__class__.__name__,
- ', '.join('(%r, %s)' % (x, y) for x, y in self)
+ ", ".join("(%r, %s)" % (x, y) for x, y in self),
)
def index(self, key):
@@ -1718,7 +1723,7 @@ class Accept(ImmutableList):
with the list API.
"""
if isinstance(key, string_types):
- for idx, (item, quality) in enumerate(self):
+ for idx, (item, _quality) in enumerate(self):
if self._value_matches(key, item):
return idx
raise ValueError(key)
@@ -1744,9 +1749,9 @@ class Accept(ImmutableList):
result = []
for value, quality in self:
if quality != 1:
- value = '%s;q=%s' % (value, quality)
+ value = "%s;q=%s" % (value, quality)
result.append(value)
- return ','.join(result)
+ return ",".join(result)
def __str__(self):
return self.to_header()
@@ -1791,75 +1796,71 @@ class Accept(ImmutableList):
class MIMEAccept(Accept):
-
"""Like :class:`Accept` but with special methods and behavior for
mimetypes.
"""
def _specificity(self, value):
- return tuple(x != '*' for x in value.split('/', 1))
+ return tuple(x != "*" for x in value.split("/", 1))
def _value_matches(self, value, item):
def _normalize(x):
x = x.lower()
- return ('*', '*') if x == '*' else x.split('/', 1)
+ return ("*", "*") if x == "*" else x.split("/", 1)
# this is from the application which is trusted. to avoid developer
# frustration we actually check these for valid values
- if '/' not in value:
- raise ValueError('invalid mimetype %r' % value)
+ if "/" not in value:
+ raise ValueError("invalid mimetype %r" % value)
value_type, value_subtype = _normalize(value)
- if value_type == '*' and value_subtype != '*':
- raise ValueError('invalid mimetype %r' % value)
+ if value_type == "*" and value_subtype != "*":
+ raise ValueError("invalid mimetype %r" % value)
- if '/' not in item:
+ if "/" not in item:
return False
item_type, item_subtype = _normalize(item)
- if item_type == '*' and item_subtype != '*':
+ if item_type == "*" and item_subtype != "*":
return False
return (
- (item_type == item_subtype == '*'
- or value_type == value_subtype == '*')
- or (item_type == value_type and (item_subtype == '*'
- or value_subtype == '*'
- or item_subtype == value_subtype))
+ item_type == item_subtype == "*" or value_type == value_subtype == "*"
+ ) or (
+ item_type == value_type
+ and (
+ item_subtype == "*"
+ or value_subtype == "*"
+ or item_subtype == value_subtype
+ )
)
@property
def accept_html(self):
"""True if this object accepts HTML."""
return (
- 'text/html' in self
- or 'application/xhtml+xml' in self
- or self.accept_xhtml
+ "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml
)
@property
def accept_xhtml(self):
"""True if this object accepts XHTML."""
- return (
- 'application/xhtml+xml' in self
- or 'application/xml' in self
- )
+ return "application/xhtml+xml" in self or "application/xml" in self
@property
def accept_json(self):
"""True if this object accepts JSON."""
- return 'application/json' in self
+ return "application/json" in self
class LanguageAccept(Accept):
-
"""Like :class:`Accept` but with normalization for languages."""
def _value_matches(self, value, item):
def _normalize(language):
return _locale_delim_re.split(language.lower())
- return item == '*' or _normalize(value) == _normalize(item)
+ return item == "*" or _normalize(value) == _normalize(item)
-class CharsetAccept(Accept):
+class CharsetAccept(Accept):
"""Like :class:`Accept` but with normalization for charsets."""
def _value_matches(self, value, item):
@@ -1868,20 +1869,22 @@ class CharsetAccept(Accept):
return codecs.lookup(name).name
except LookupError:
return name.lower()
- return item == '*' or _normalize(value) == _normalize(item)
+
+ return item == "*" or _normalize(value) == _normalize(item)
def cache_property(key, empty, type):
"""Return a new property object for a cache header. Useful if you
want to add support for a cache extension in a subclass."""
- return property(lambda x: x._get_cache_value(key, empty, type),
- lambda x, v: x._set_cache_value(key, v, type),
- lambda x: x._del_cache_value(key),
- 'accessor for %r' % key)
+ return property(
+ lambda x: x._get_cache_value(key, empty, type),
+ lambda x, v: x._set_cache_value(key, v, type),
+ lambda x: x._del_cache_value(key),
+ "accessor for %r" % key,
+ )
class _CacheControl(UpdateDictMixin, dict):
-
"""Subclass of a dict that stores values for a Cache-Control header. It
has accessors for all the cache-control directives specified in RFC 2616.
The class does not differentiate between request and response directives.
@@ -1913,10 +1916,10 @@ class _CacheControl(UpdateDictMixin, dict):
no longer existing `CacheControl` class.
"""
- no_cache = cache_property('no-cache', '*', None)
- no_store = cache_property('no-store', None, bool)
- max_age = cache_property('max-age', -1, int)
- no_transform = cache_property('no-transform', None, None)
+ no_cache = cache_property("no-cache", "*", None)
+ no_store = cache_property("no-store", None, bool)
+ max_age = cache_property("max-age", -1, int)
+ no_transform = cache_property("no-transform", None, None)
def __init__(self, values=(), on_update=None):
dict.__init__(self, values or ())
@@ -1966,16 +1969,13 @@ class _CacheControl(UpdateDictMixin, dict):
return self.to_header()
def __repr__(self):
- return '<%s %s>' % (
+ return "<%s %s>" % (
self.__class__.__name__,
- " ".join(
- "%s=%r" % (k, v) for k, v in sorted(self.items())
- ),
+ " ".join("%s=%r" % (k, v) for k, v in sorted(self.items())),
)
class RequestCacheControl(ImmutableDictMixin, _CacheControl):
-
"""A cache control for requests. This is immutable and gives access
to all the request-relevant cache control headers.
@@ -1989,14 +1989,13 @@ class RequestCacheControl(ImmutableDictMixin, _CacheControl):
both for request and response.
"""
- max_stale = cache_property('max-stale', '*', int)
- min_fresh = cache_property('min-fresh', '*', int)
- no_transform = cache_property('no-transform', None, None)
- only_if_cached = cache_property('only-if-cached', None, bool)
+ max_stale = cache_property("max-stale", "*", int)
+ min_fresh = cache_property("min-fresh", "*", int)
+ no_transform = cache_property("no-transform", None, None)
+ only_if_cached = cache_property("only-if-cached", None, bool)
class ResponseCacheControl(_CacheControl):
-
"""A cache control for responses. Unlike :class:`RequestCacheControl`
this is mutable and gives access to response-relevant cache control
headers.
@@ -2011,11 +2010,11 @@ class ResponseCacheControl(_CacheControl):
both for request and response.
"""
- public = cache_property('public', None, bool)
- private = cache_property('private', '*', None)
- must_revalidate = cache_property('must-revalidate', None, bool)
- proxy_revalidate = cache_property('proxy-revalidate', None, bool)
- s_maxage = cache_property('s-maxage', None, None)
+ public = cache_property("public", None, bool)
+ private = cache_property("private", "*", None)
+ must_revalidate = cache_property("must-revalidate", None, bool)
+ proxy_revalidate = cache_property("proxy-revalidate", None, bool)
+ s_maxage = cache_property("s-maxage", None, None)
# attach cache_property to the _CacheControl as staticmethod
@@ -2024,7 +2023,6 @@ _CacheControl.cache_property = staticmethod(cache_property)
class CallbackDict(UpdateDictMixin, dict):
-
"""A dict that calls a function passed every time something is changed.
The function is passed the dict instance.
"""
@@ -2034,14 +2032,10 @@ class CallbackDict(UpdateDictMixin, dict):
self.on_update = on_update
def __repr__(self):
- return '<%s %s>' % (
- self.__class__.__name__,
- dict.__repr__(self)
- )
+ return "<%s %s>" % (self.__class__.__name__, dict.__repr__(self))
class HeaderSet(collections_abc.MutableSet):
-
"""Similar to the :class:`ETags` class this implements a set-like structure.
Unlike :class:`ETags` this is case insensitive and used for vary, allow, and
content-language headers.
@@ -2153,7 +2147,7 @@ class HeaderSet(collections_abc.MutableSet):
def to_header(self):
"""Convert the header set into an HTTP header string."""
- return ', '.join(map(quote_header_value, self._headers))
+ return ", ".join(map(quote_header_value, self._headers))
def __getitem__(self, idx):
return self._headers[idx]
@@ -2188,14 +2182,10 @@ class HeaderSet(collections_abc.MutableSet):
return self.to_header()
def __repr__(self):
- return '%s(%r)' % (
- self.__class__.__name__,
- self._headers
- )
+ return "%s(%r)" % (self.__class__.__name__, self._headers)
class ETags(collections_abc.Container, collections_abc.Iterable):
-
"""A set that can be used to check if one etag is present in a collection
of etags.
"""
@@ -2245,15 +2235,14 @@ class ETags(collections_abc.Container, collections_abc.Iterable):
def to_header(self):
"""Convert the etags set into a HTTP header string."""
if self.star_tag:
- return '*'
- return ', '.join(
- ['"%s"' % x for x in self._strong]
- + ['W/"%s"' % x for x in self._weak]
+ return "*"
+ return ", ".join(
+ ['"%s"' % x for x in self._strong] + ['W/"%s"' % x for x in self._weak]
)
def __call__(self, etag=None, data=None, include_weak=False):
if [etag, data].count(None) != 1:
- raise TypeError('either tag or data required, but at least one')
+ raise TypeError("either tag or data required, but at least one")
if etag is None:
etag = generate_etag(data)
if include_weak:
@@ -2276,11 +2265,10 @@ class ETags(collections_abc.Container, collections_abc.Iterable):
return self.contains(etag)
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, str(self))
+ return "<%s %r>" % (self.__class__.__name__, str(self))
class IfRange(object):
-
"""Very simple object that represents the `If-Range` header in parsed
form. It will either have neither a etag or date or one of either but
never both.
@@ -2301,13 +2289,13 @@ class IfRange(object):
return http_date(self.date)
if self.etag is not None:
return quote_etag(self.etag)
- return ''
+ return ""
def __str__(self):
return self.to_header()
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, str(self))
+ return "<%s %r>" % (self.__class__.__name__, str(self))
class Range(object):
@@ -2339,7 +2327,7 @@ class Range(object):
exactly one range and it is satisfiable it returns a ``(start, stop)``
tuple, otherwise `None`.
"""
- if self.units != 'bytes' or length is None or len(self.ranges) != 1:
+ if self.units != "bytes" or length is None or len(self.ranges) != 1:
return None
start, end = self.ranges[0]
if end is None:
@@ -2362,10 +2350,10 @@ class Range(object):
ranges = []
for begin, end in self.ranges:
if end is None:
- ranges.append('%s-' % begin if begin >= 0 else str(begin))
+ ranges.append("%s-" % begin if begin >= 0 else str(begin))
else:
- ranges.append('%s-%s' % (begin, end - 1))
- return '%s=%s' % (self.units, ','.join(ranges))
+ ranges.append("%s-%s" % (begin, end - 1))
+ return "%s=%s" % (self.units, ",".join(ranges))
def to_content_range_header(self, length):
"""Converts the object into `Content-Range` HTTP header,
@@ -2373,32 +2361,33 @@ class Range(object):
"""
range_for_length = self.range_for_length(length)
if range_for_length is not None:
- return '%s %d-%d/%d' % (self.units,
- range_for_length[0],
- range_for_length[1] - 1, length)
+ return "%s %d-%d/%d" % (
+ self.units,
+ range_for_length[0],
+ range_for_length[1] - 1,
+ length,
+ )
return None
def __str__(self):
return self.to_header()
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, str(self))
+ return "<%s %r>" % (self.__class__.__name__, str(self))
class ContentRange(object):
-
"""Represents the content range header.
.. versionadded:: 0.7
"""
def __init__(self, units, start, stop, length=None, on_update=None):
- assert is_byte_range_valid(start, stop, length), \
- 'Bad range provided'
+ assert is_byte_range_valid(start, stop, length), "Bad range provided"
self.on_update = on_update
self.set(start, stop, length, units)
- def _callback_property(name):
+ def _callback_property(name): # noqa: B902
def fget(self):
return getattr(self, name)
@@ -2406,22 +2395,23 @@ class ContentRange(object):
setattr(self, name, value)
if self.on_update is not None:
self.on_update(self)
+
return property(fget, fset)
#: The units to use, usually "bytes"
- units = _callback_property('_units')
+ units = _callback_property("_units")
#: The start point of the range or `None`.
- start = _callback_property('_start')
+ start = _callback_property("_start")
#: The stop point of the range (non-inclusive) or `None`. Can only be
#: `None` if also start is `None`.
- stop = _callback_property('_stop')
+ stop = _callback_property("_stop")
#: The length of the range or `None`.
- length = _callback_property('_length')
+ length = _callback_property("_length")
+ del _callback_property
- def set(self, start, stop, length=None, units='bytes'):
+ def set(self, start, stop, length=None, units="bytes"):
"""Simple method to update the ranges."""
- assert is_byte_range_valid(start, stop, length), \
- 'Bad range provided'
+ assert is_byte_range_valid(start, stop, length), "Bad range provided"
self._units = units
self._start = start
self._stop = stop
@@ -2437,19 +2427,14 @@ class ContentRange(object):
def to_header(self):
if self.units is None:
- return ''
+ return ""
if self.length is None:
- length = '*'
+ length = "*"
else:
length = self.length
if self.start is None:
- return '%s */%s' % (self.units, length)
- return '%s %s-%s/%s' % (
- self.units,
- self.start,
- self.stop - 1,
- length
- )
+ return "%s */%s" % (self.units, length)
+ return "%s %s-%s/%s" % (self.units, self.start, self.stop - 1, length)
def __nonzero__(self):
return self.units is not None
@@ -2460,11 +2445,10 @@ class ContentRange(object):
return self.to_header()
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, str(self))
+ return "<%s %r>" % (self.__class__.__name__, str(self))
class Authorization(ImmutableDictMixin, dict):
-
"""Represents an `Authorization` header sent by the client. You should
not create this kind of object yourself but use it when it's returned by
the `parse_authorization_header` function.
@@ -2481,77 +2465,107 @@ class Authorization(ImmutableDictMixin, dict):
dict.__init__(self, data or {})
self.type = auth_type
- username = property(lambda x: x.get('username'), doc='''
+ username = property(
+ lambda self: self.get("username"),
+ doc="""
The username transmitted. This is set for both basic and digest
- auth all the time.''')
- password = property(lambda x: x.get('password'), doc='''
+ auth all the time.""",
+ )
+ password = property(
+ lambda self: self.get("password"),
+ doc="""
When the authentication type is basic this is the password
- transmitted by the client, else `None`.''')
- realm = property(lambda x: x.get('realm'), doc='''
- This is the server realm sent back for HTTP digest auth.''')
- nonce = property(lambda x: x.get('nonce'), doc='''
+ transmitted by the client, else `None`.""",
+ )
+ realm = property(
+ lambda self: self.get("realm"),
+ doc="""
+ This is the server realm sent back for HTTP digest auth.""",
+ )
+ nonce = property(
+ lambda self: self.get("nonce"),
+ doc="""
The nonce the server sent for digest auth, sent back by the client.
A nonce should be unique for every 401 response for HTTP digest
- auth.''')
- uri = property(lambda x: x.get('uri'), doc='''
+ auth.""",
+ )
+ uri = property(
+ lambda self: self.get("uri"),
+ doc="""
The URI from Request-URI of the Request-Line; duplicated because
proxies are allowed to change the Request-Line in transit. HTTP
- digest auth only.''')
- nc = property(lambda x: x.get('nc'), doc='''
+ digest auth only.""",
+ )
+ nc = property(
+ lambda self: self.get("nc"),
+ doc="""
The nonce count value transmitted by clients if a qop-header is
- also transmitted. HTTP digest auth only.''')
- cnonce = property(lambda x: x.get('cnonce'), doc='''
+ also transmitted. HTTP digest auth only.""",
+ )
+ cnonce = property(
+ lambda self: self.get("cnonce"),
+ doc="""
If the server sent a qop-header in the ``WWW-Authenticate``
header, the client has to provide this value for HTTP digest auth.
- See the RFC for more details.''')
- response = property(lambda x: x.get('response'), doc='''
+ See the RFC for more details.""",
+ )
+ response = property(
+ lambda self: self.get("response"),
+ doc="""
A string of 32 hex digits computed as defined in RFC 2617, which
- proves that the user knows a password. Digest auth only.''')
- opaque = property(lambda x: x.get('opaque'), doc='''
+ proves that the user knows a password. Digest auth only.""",
+ )
+ opaque = property(
+ lambda self: self.get("opaque"),
+ doc="""
The opaque header from the server returned unchanged by the client.
It is recommended that this string be base64 or hexadecimal data.
- Digest auth only.''')
- qop = property(lambda x: x.get('qop'), doc='''
+ Digest auth only.""",
+ )
+ qop = property(
+ lambda self: self.get("qop"),
+ doc="""
Indicates what "quality of protection" the client has applied to
the message for HTTP digest auth. Note that this is a single token,
- not a quoted list of alternatives as in WWW-Authenticate.''')
+ not a quoted list of alternatives as in WWW-Authenticate.""",
+ )
class WWWAuthenticate(UpdateDictMixin, dict):
-
"""Provides simple access to `WWW-Authenticate` headers."""
#: list of keys that require quoting in the generated header
- _require_quoting = frozenset(['domain', 'nonce', 'opaque', 'realm', 'qop'])
+ _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"])
def __init__(self, auth_type=None, values=None, on_update=None):
dict.__init__(self, values or ())
if auth_type:
- self['__auth_type__'] = auth_type
+ self["__auth_type__"] = auth_type
self.on_update = on_update
- def set_basic(self, realm='authentication required'):
+ def set_basic(self, realm="authentication required"):
"""Clear the auth info and enable basic auth."""
dict.clear(self)
- dict.update(self, {'__auth_type__': 'basic', 'realm': realm})
+ dict.update(self, {"__auth_type__": "basic", "realm": realm})
if self.on_update:
self.on_update(self)
- def set_digest(self, realm, nonce, qop=('auth',), opaque=None,
- algorithm=None, stale=False):
+ def set_digest(
+ self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False
+ ):
"""Clear the auth info and enable digest auth."""
d = {
- '__auth_type__': 'digest',
- 'realm': realm,
- 'nonce': nonce,
- 'qop': dump_header(qop)
+ "__auth_type__": "digest",
+ "realm": realm,
+ "nonce": nonce,
+ "qop": dump_header(qop),
}
if stale:
- d['stale'] = 'TRUE'
+ d["stale"] = "TRUE"
if opaque is not None:
- d['opaque'] = opaque
+ d["opaque"] = opaque
if algorithm is not None:
- d['algorithm'] = algorithm
+ d["algorithm"] = algorithm
dict.clear(self)
dict.update(self, d)
if self.on_update:
@@ -2560,23 +2574,30 @@ class WWWAuthenticate(UpdateDictMixin, dict):
def to_header(self):
"""Convert the stored values into a WWW-Authenticate header."""
d = dict(self)
- auth_type = d.pop('__auth_type__', None) or 'basic'
- return '%s %s' % (auth_type.title(), ', '.join([
- '%s=%s' % (key, quote_header_value(value,
- allow_token=key not in self._require_quoting))
- for key, value in iteritems(d)
- ]))
+ auth_type = d.pop("__auth_type__", None) or "basic"
+ return "%s %s" % (
+ auth_type.title(),
+ ", ".join(
+ [
+ "%s=%s"
+ % (
+ key,
+ quote_header_value(
+ value, allow_token=key not in self._require_quoting
+ ),
+ )
+ for key, value in iteritems(d)
+ ]
+ ),
+ )
def __str__(self):
return self.to_header()
def __repr__(self):
- return '<%s %r>' % (
- self.__class__.__name__,
- self.to_header()
- )
+ return "<%s %r>" % (self.__class__.__name__, self.to_header())
- def auth_property(name, doc=None):
+ def auth_property(name, doc=None): # noqa: B902
"""A static helper function for subclasses to add extra authentication
system properties onto a class::
@@ -2592,70 +2613,90 @@ class WWWAuthenticate(UpdateDictMixin, dict):
self.pop(name, None)
else:
self[name] = str(value)
+
return property(lambda x: x.get(name), _set_value, doc=doc)
- def _set_property(name, doc=None):
+ def _set_property(name, doc=None): # noqa: B902
def fget(self):
def on_update(header_set):
if not header_set and name in self:
del self[name]
elif header_set:
self[name] = header_set.to_header()
+
return parse_set_header(self.get(name), on_update)
+
return property(fget, doc=doc)
- type = auth_property('__auth_type__', doc='''
- The type of the auth mechanism. HTTP currently specifies
- `Basic` and `Digest`.''')
- realm = auth_property('realm', doc='''
- A string to be displayed to users so they know which username and
- password to use. This string should contain at least the name of
- the host performing the authentication and might additionally
- indicate the collection of users who might have access.''')
- domain = _set_property('domain', doc='''
- A list of URIs that define the protection space. If a URI is an
- absolute path, it is relative to the canonical root URL of the
- server being accessed.''')
- nonce = auth_property('nonce', doc='''
+ type = auth_property(
+ "__auth_type__",
+ doc="""The type of the auth mechanism. HTTP currently specifies
+ ``Basic`` and ``Digest``.""",
+ )
+ realm = auth_property(
+ "realm",
+ doc="""A string to be displayed to users so they know which
+ username and password to use. This string should contain at
+ least the name of the host performing the authentication and
+ might additionally indicate the collection of users who might
+ have access.""",
+ )
+ domain = _set_property(
+ "domain",
+ doc="""A list of URIs that define the protection space. If a URI
+ is an absolute path, it is relative to the canonical root URL of
+ the server being accessed.""",
+ )
+ nonce = auth_property(
+ "nonce",
+ doc="""
A server-specified data string which should be uniquely generated
- each time a 401 response is made. It is recommended that this
- string be base64 or hexadecimal data.''')
- opaque = auth_property('opaque', doc='''
- A string of data, specified by the server, which should be returned
- by the client unchanged in the Authorization header of subsequent
- requests with URIs in the same protection space. It is recommended
- that this string be base64 or hexadecimal data.''')
- algorithm = auth_property('algorithm', doc='''
- A string indicating a pair of algorithms used to produce the digest
- and a checksum. If this is not present it is assumed to be "MD5".
- If the algorithm is not understood, the challenge should be ignored
- (and a different one used, if there is more than one).''')
- qop = _set_property('qop', doc='''
- A set of quality-of-privacy directives such as auth and auth-int.''')
-
- def _get_stale(self):
- val = self.get('stale')
+ each time a 401 response is made. It is recommended that this
+ string be base64 or hexadecimal data.""",
+ )
+ opaque = auth_property(
+ "opaque",
+ doc="""A string of data, specified by the server, which should
+ be returned by the client unchanged in the Authorization header
+ of subsequent requests with URIs in the same protection space.
+ It is recommended that this string be base64 or hexadecimal
+ data.""",
+ )
+ algorithm = auth_property(
+ "algorithm",
+ doc="""A string indicating a pair of algorithms used to produce
+ the digest and a checksum. If this is not present it is assumed
+ to be "MD5". If the algorithm is not understood, the challenge
+ should be ignored (and a different one used, if there is more
+ than one).""",
+ )
+ qop = _set_property(
+ "qop",
+ doc="""A set of quality-of-privacy directives such as auth and
+ auth-int.""",
+ )
+
+ @property
+ def stale(self):
+ """A flag, indicating that the previous request from the client
+ was rejected because the nonce value was stale.
+ """
+ val = self.get("stale")
if val is not None:
- return val.lower() == 'true'
+ return val.lower() == "true"
- def _set_stale(self, value):
+ @stale.setter
+ def stale(self, value):
if value is None:
- self.pop('stale', None)
+ self.pop("stale", None)
else:
- self['stale'] = 'TRUE' if value else 'FALSE'
- stale = property(_get_stale, _set_stale, doc='''
- A flag, indicating that the previous request from the client was
- rejected because the nonce value was stale.''')
- del _get_stale, _set_stale
-
- # make auth_property a staticmethod so that subclasses of
- # `WWWAuthenticate` can use it for new properties.
+ self["stale"] = "TRUE" if value else "FALSE"
+
auth_property = staticmethod(auth_property)
del _set_property
class FileStorage(object):
-
"""The :class:`FileStorage` class is a thin wrapper over incoming files.
It is used by the request object to represent uploaded files. All the
attributes of the wrapper stream are proxied by the file storage so
@@ -2663,9 +2704,15 @@ class FileStorage(object):
``storage.stream.read()``.
"""
- def __init__(self, stream=None, filename=None, name=None,
- content_type=None, content_length=None,
- headers=None):
+ def __init__(
+ self,
+ stream=None,
+ filename=None,
+ name=None,
+ content_type=None,
+ content_length=None,
+ headers=None,
+ ):
self.name = name
self.stream = stream or BytesIO()
@@ -2674,41 +2721,39 @@ class FileStorage(object):
# skip things like <fdopen>, <stderr> etc. Python marks these
# special filenames with angular brackets.
if filename is None:
- filename = getattr(stream, 'name', None)
+ filename = getattr(stream, "name", None)
s = make_literal_wrapper(filename)
- if filename and filename[0] == s('<') and filename[-1] == s('>'):
+ if filename and filename[0] == s("<") and filename[-1] == s(">"):
filename = None
# On Python 3 we want to make sure the filename is always unicode.
# This might not be if the name attribute is bytes due to the
# file being opened from the bytes API.
if not PY2 and isinstance(filename, bytes):
- filename = filename.decode(get_filesystem_encoding(),
- 'replace')
+ filename = filename.decode(get_filesystem_encoding(), "replace")
self.filename = filename
if headers is None:
headers = Headers()
self.headers = headers
if content_type is not None:
- headers['Content-Type'] = content_type
+ headers["Content-Type"] = content_type
if content_length is not None:
- headers['Content-Length'] = str(content_length)
+ headers["Content-Length"] = str(content_length)
def _parse_content_type(self):
- if not hasattr(self, '_parsed_content_type'):
- self._parsed_content_type = \
- parse_options_header(self.content_type)
+ if not hasattr(self, "_parsed_content_type"):
+ self._parsed_content_type = parse_options_header(self.content_type)
@property
def content_type(self):
"""The content-type sent in the header. Usually not available"""
- return self.headers.get('content-type')
+ return self.headers.get("content-type")
@property
def content_length(self):
"""The content-length sent in the header. Usually not available"""
- return int(self.headers.get('content-length') or 0)
+ return int(self.headers.get("content-length") or 0)
@property
def mimetype(self):
@@ -2748,9 +2793,10 @@ class FileStorage(object):
:func:`shutil.copyfileobj`.
"""
from shutil import copyfileobj
+
close_dst = False
if isinstance(dst, string_types):
- dst = open(dst, 'wb')
+ dst = open(dst, "wb")
close_dst = True
try:
copyfileobj(self.stream, dst, buffer_size)
@@ -2767,6 +2813,7 @@ class FileStorage(object):
def __nonzero__(self):
return bool(self.filename)
+
__bool__ = __nonzero__
def __getattr__(self, name):
@@ -2784,15 +2831,22 @@ class FileStorage(object):
return iter(self.stream)
def __repr__(self):
- return '<%s: %r (%r)>' % (
+ return "<%s: %r (%r)>" % (
self.__class__.__name__,
self.filename,
- self.content_type
+ self.content_type,
)
# circular dependencies
-from werkzeug.http import dump_options_header, dump_header, generate_etag, \
- quote_header_value, parse_set_header, unquote_etag, quote_etag, \
- parse_options_header, http_date, is_byte_range_valid
-from werkzeug import exceptions
+from . import exceptions
+from .http import dump_header
+from .http import dump_options_header
+from .http import generate_etag
+from .http import http_date
+from .http import is_byte_range_valid
+from .http import parse_options_header
+from .http import parse_set_header
+from .http import quote_etag
+from .http import quote_header_value
+from .http import unquote_etag
diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py
index cb984b03..beb7729e 100644
--- a/src/werkzeug/debug/__init__.py
+++ b/src/werkzeug/debug/__init__.py
@@ -8,33 +8,35 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+import getpass
+import hashlib
+import json
+import mimetypes
import os
import pkgutil
import re
import sys
-import uuid
-import json
import time
-import getpass
-import hashlib
-import mimetypes
+import uuid
from itertools import chain
-from os.path import join, basename
-from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response
-from werkzeug.http import parse_cookie
-from werkzeug.debug.tbtools import get_current_traceback, render_console_html
-from werkzeug.debug.console import Console
-from werkzeug.security import gen_salt
-from werkzeug._internal import _log
-from werkzeug._compat import text_type
+from os.path import basename
+from os.path import join
-
-# DEPRECATED
-from werkzeug.debug.repr import debug_repr as _debug_repr
+from .._compat import text_type
+from .._internal import _log
+from ..http import parse_cookie
+from ..security import gen_salt
+from ..wrappers import BaseRequest as Request
+from ..wrappers import BaseResponse as Response
+from .console import Console
+from .repr import debug_repr as _debug_repr
+from .tbtools import get_current_traceback
+from .tbtools import render_console_html
def debug_repr(*args, **kwargs):
import warnings
+
warnings.warn(
"'debug_repr' has moved to 'werkzeug.debug.repr.debug_repr'"
" as of version 0.7. This old import will be removed in version"
@@ -51,8 +53,8 @@ PIN_TIME = 60 * 60 * 24 * 7
def hash_pin(pin):
if isinstance(pin, text_type):
- pin = pin.encode('utf-8', 'replace')
- return hashlib.md5(pin + b'shittysalt').hexdigest()[:12]
+ pin = pin.encode("utf-8", "replace")
+ return hashlib.md5(pin + b"shittysalt").hexdigest()[:12]
_machine_id = None
@@ -67,9 +69,9 @@ def get_machine_id():
def _generate():
# Potential sources of secret information on linux. The machine-id
# is stable across boots, the boot id is not
- for filename in '/etc/machine-id', '/proc/sys/kernel/random/boot_id':
+ for filename in "/etc/machine-id", "/proc/sys/kernel/random/boot_id":
try:
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
return f.readline().strip()
except IOError:
continue
@@ -81,8 +83,10 @@ def get_machine_id():
# Google App Engine
# See https://github.com/pallets/werkzeug/issues/925
from subprocess import Popen, PIPE
- dump = Popen(['ioreg', '-c', 'IOPlatformExpertDevice', '-d', '2'],
- stdout=PIPE).communicate()[0]
+
+ dump = Popen(
+ ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE
+ ).communicate()[0]
match = re.search(b'"serial-number" = <([^>]+)', dump)
if match is not None:
return match.group(1)
@@ -100,12 +104,15 @@ def get_machine_id():
pass
if wr is not None:
try:
- with wr.OpenKey(wr.HKEY_LOCAL_MACHINE,
- 'SOFTWARE\\Microsoft\\Cryptography', 0,
- wr.KEY_READ | wr.KEY_WOW64_64KEY) as rk:
- machineGuid, wrType = wr.QueryValueEx(rk, 'MachineGuid')
+ with wr.OpenKey(
+ wr.HKEY_LOCAL_MACHINE,
+ "SOFTWARE\\Microsoft\\Cryptography",
+ 0,
+ wr.KEY_READ | wr.KEY_WOW64_64KEY,
+ ) as rk:
+ machineGuid, wrType = wr.QueryValueEx(rk, "MachineGuid")
if wrType == wr.REG_SZ:
- return machineGuid.encode('utf-8')
+ return machineGuid.encode("utf-8")
else:
return machineGuid
except WindowsError:
@@ -116,7 +123,6 @@ def get_machine_id():
class _ConsoleFrame(object):
-
"""Helper class so that we can reuse the frame console code for the
standalone console.
"""
@@ -134,24 +140,23 @@ def get_pin_and_cookie_name(app):
Second item in the resulting tuple is the cookie name for remembering.
"""
- pin = os.environ.get('WERKZEUG_DEBUG_PIN')
+ pin = os.environ.get("WERKZEUG_DEBUG_PIN")
rv = None
num = None
# Pin was explicitly disabled
- if pin == 'off':
+ if pin == "off":
return None, None
# Pin was provided explicitly
- if pin is not None and pin.replace('-', '').isdigit():
+ if pin is not None and pin.replace("-", "").isdigit():
# If there are separators in the pin, return it directly
- if '-' in pin:
+ if "-" in pin:
rv = pin
else:
num = pin
- modname = getattr(app, '__module__',
- getattr(app.__class__, '__module__'))
+ modname = getattr(app, "__module__", getattr(app.__class__, "__module__"))
try:
# `getpass.getuser()` imports the `pwd` module,
@@ -167,42 +172,41 @@ def get_pin_and_cookie_name(app):
probably_public_bits = [
username,
modname,
- getattr(app, '__name__', getattr(app.__class__, '__name__')),
- getattr(mod, '__file__', None),
+ getattr(app, "__name__", getattr(app.__class__, "__name__")),
+ getattr(mod, "__file__", None),
]
# This information is here to make it harder for an attacker to
# guess the cookie name. They are unlikely to be contained anywhere
# within the unauthenticated debug page.
- private_bits = [
- str(uuid.getnode()),
- get_machine_id(),
- ]
+ private_bits = [str(uuid.getnode()), get_machine_id()]
h = hashlib.md5()
for bit in chain(probably_public_bits, private_bits):
if not bit:
continue
if isinstance(bit, text_type):
- bit = bit.encode('utf-8')
+ bit = bit.encode("utf-8")
h.update(bit)
- h.update(b'cookiesalt')
+ h.update(b"cookiesalt")
- cookie_name = '__wzd' + h.hexdigest()[:20]
+ cookie_name = "__wzd" + h.hexdigest()[:20]
# If we need to generate a pin we salt it a bit more so that we don't
# end up with the same value and generate out 9 digits
if num is None:
- h.update(b'pinsalt')
- num = ('%09d' % int(h.hexdigest(), 16))[:9]
+ h.update(b"pinsalt")
+ num = ("%09d" % int(h.hexdigest(), 16))[:9]
# Format the pincode in groups of digits for easier remembering if
# we don't have a result yet.
if rv is None:
for group_size in 5, 4, 3:
if len(num) % group_size == 0:
- rv = '-'.join(num[x:x + group_size].rjust(group_size, '0')
- for x in range(0, len(num), group_size))
+ rv = "-".join(
+ num[x : x + group_size].rjust(group_size, "0")
+ for x in range(0, len(num), group_size)
+ )
break
else:
rv = num
@@ -240,12 +244,21 @@ class DebuggedApplication(object):
:param pin_logging: enables the logging of the pin system.
"""
- def __init__(self, app, evalex=False, request_key='werkzeug.request',
- console_path='/console', console_init_func=None,
- show_hidden_frames=False, lodgeit_url=None,
- pin_security=True, pin_logging=True):
+ def __init__(
+ self,
+ app,
+ evalex=False,
+ request_key="werkzeug.request",
+ console_path="/console",
+ console_init_func=None,
+ show_hidden_frames=False,
+ lodgeit_url=None,
+ pin_security=True,
+ pin_logging=True,
+ ):
if lodgeit_url is not None:
from warnings import warn
+
warn(
"'lodgeit_url' is no longer used as of version 0.9 and"
" will be removed in version 1.0. Werkzeug uses"
@@ -269,19 +282,17 @@ class DebuggedApplication(object):
self.pin_logging = pin_logging
if pin_security:
# Print out the pin for the debugger on standard out.
- if os.environ.get('WERKZEUG_RUN_MAIN') == 'true' and \
- pin_logging:
- _log('warning', ' * Debugger is active!')
+ if os.environ.get("WERKZEUG_RUN_MAIN") == "true" and pin_logging:
+ _log("warning", " * Debugger is active!")
if self.pin is None:
- _log('warning', ' * Debugger PIN disabled. '
- 'DEBUGGER UNSECURED!')
+ _log("warning", " * Debugger PIN disabled. DEBUGGER UNSECURED!")
else:
- _log('info', ' * Debugger PIN: %s' % self.pin)
+ _log("info", " * Debugger PIN: %s" % self.pin)
else:
self.pin = None
def _get_pin(self):
- if not hasattr(self, '_pin'):
+ if not hasattr(self, "_pin"):
self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app)
return self._pin
@@ -294,7 +305,7 @@ class DebuggedApplication(object):
@property
def pin_cookie_name(self):
"""The name of the pin cookie."""
- if not hasattr(self, '_pin_cookie'):
+ if not hasattr(self, "_pin_cookie"):
self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app)
return self._pin_cookie
@@ -305,46 +316,51 @@ class DebuggedApplication(object):
app_iter = self.app(environ, start_response)
for item in app_iter:
yield item
- if hasattr(app_iter, 'close'):
+ if hasattr(app_iter, "close"):
app_iter.close()
except Exception:
- if hasattr(app_iter, 'close'):
+ if hasattr(app_iter, "close"):
app_iter.close()
traceback = get_current_traceback(
- skip=1, show_hidden_frames=self.show_hidden_frames,
- ignore_system_exceptions=True)
+ skip=1,
+ show_hidden_frames=self.show_hidden_frames,
+ ignore_system_exceptions=True,
+ )
for frame in traceback.frames:
self.frames[frame.id] = frame
self.tracebacks[traceback.id] = traceback
try:
- start_response('500 INTERNAL SERVER ERROR', [
- ('Content-Type', 'text/html; charset=utf-8'),
- # Disable Chrome's XSS protection, the debug
- # output can cause false-positives.
- ('X-XSS-Protection', '0'),
- ])
+ start_response(
+ "500 INTERNAL SERVER ERROR",
+ [
+ ("Content-Type", "text/html; charset=utf-8"),
+ # Disable Chrome's XSS protection, the debug
+ # output can cause false-positives.
+ ("X-XSS-Protection", "0"),
+ ],
+ )
except Exception:
# if we end up here there has been output but an error
# occurred. in that situation we can do nothing fancy any
# more, better log something into the error log and fall
# back gracefully.
- environ['wsgi.errors'].write(
- 'Debugging middleware caught exception in streamed '
- 'response at a point where response headers were already '
- 'sent.\n')
+ environ["wsgi.errors"].write(
+ "Debugging middleware caught exception in streamed "
+ "response at a point where response headers were already "
+ "sent.\n"
+ )
else:
is_trusted = bool(self.check_pin_trust(environ))
- yield traceback.render_full(evalex=self.evalex,
- evalex_trusted=is_trusted,
- secret=self.secret) \
- .encode('utf-8', 'replace')
+ yield traceback.render_full(
+ evalex=self.evalex, evalex_trusted=is_trusted, secret=self.secret
+ ).encode("utf-8", "replace")
- traceback.log(environ['wsgi.errors'])
+ traceback.log(environ["wsgi.errors"])
def execute_command(self, request, command, frame):
"""Execute a command in a console."""
- return Response(frame.console.eval(command), mimetype='text/html')
+ return Response(frame.console.eval(command), mimetype="text/html")
def display_console(self, request):
"""Display a standalone shell."""
@@ -353,30 +369,30 @@ class DebuggedApplication(object):
ns = {}
else:
ns = dict(self.console_init_func())
- ns.setdefault('app', self.app)
+ ns.setdefault("app", self.app)
self.frames[0] = _ConsoleFrame(ns)
is_trusted = bool(self.check_pin_trust(request.environ))
- return Response(render_console_html(secret=self.secret,
- evalex_trusted=is_trusted),
- mimetype='text/html')
+ return Response(
+ render_console_html(secret=self.secret, evalex_trusted=is_trusted),
+ mimetype="text/html",
+ )
def paste_traceback(self, request, traceback):
"""Paste the traceback and return a JSON response."""
rv = traceback.paste()
- return Response(json.dumps(rv), mimetype='application/json')
+ return Response(json.dumps(rv), mimetype="application/json")
def get_resource(self, request, filename):
"""Return a static resource from the shared folder."""
- filename = join('shared', basename(filename))
+ filename = join("shared", basename(filename))
try:
data = pkgutil.get_data(__package__, filename)
except OSError:
data = None
if data is not None:
- mimetype = mimetypes.guess_type(filename)[0] \
- or 'application/octet-stream'
+ mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream"
return Response(data, mimetype=mimetype)
- return Response('Not Found', status=404)
+ return Response("Not Found", status=404)
def check_pin_trust(self, environ):
"""Checks if the request passed the pin test. This returns `True` if the
@@ -387,9 +403,9 @@ class DebuggedApplication(object):
if self.pin is None:
return True
val = parse_cookie(environ).get(self.pin_cookie_name)
- if not val or '|' not in val:
+ if not val or "|" not in val:
return False
- ts, pin_hash = val.split('|', 1)
+ ts, pin_hash = val.split("|", 1)
if not ts.isdigit():
return False
if pin_hash != hash_pin(self.pin):
@@ -426,23 +442,23 @@ class DebuggedApplication(object):
# Otherwise go through pin based authentication
else:
- entered_pin = request.args.get('pin')
- if entered_pin.strip().replace('-', '') == \
- self.pin.replace('-', ''):
+ entered_pin = request.args.get("pin")
+ if entered_pin.strip().replace("-", "") == self.pin.replace("-", ""):
self._failed_pin_auth = 0
auth = True
else:
self._fail_pin_auth()
- rv = Response(json.dumps({
- 'auth': auth,
- 'exhausted': exhausted,
- }), mimetype='application/json')
+ rv = Response(
+ json.dumps({"auth": auth, "exhausted": exhausted}),
+ mimetype="application/json",
+ )
if auth:
- rv.set_cookie(self.pin_cookie_name, '%s|%s' % (
- int(time.time()),
- hash_pin(self.pin)
- ), httponly=True)
+ rv.set_cookie(
+ self.pin_cookie_name,
+ "%s|%s" % (int(time.time()), hash_pin(self.pin)),
+ httponly=True,
+ )
elif bad_cookie:
rv.delete_cookie(self.pin_cookie_name)
return rv
@@ -450,10 +466,11 @@ class DebuggedApplication(object):
def log_pin_request(self):
"""Log the pin if needed."""
if self.pin_logging and self.pin is not None:
- _log('info', ' * To enable the debugger you need to '
- 'enter the security pin:')
- _log('info', ' * Debugger pin code: %s' % self.pin)
- return Response('')
+ _log(
+ "info", " * To enable the debugger you need to enter the security pin:"
+ )
+ _log("info", " * Debugger pin code: %s" % self.pin)
+ return Response("")
def __call__(self, environ, start_response):
"""Dispatch the requests."""
@@ -462,26 +479,32 @@ class DebuggedApplication(object):
# any more!
request = Request(environ)
response = self.debug_application
- if request.args.get('__debugger__') == 'yes':
- cmd = request.args.get('cmd')
- arg = request.args.get('f')
- secret = request.args.get('s')
- traceback = self.tracebacks.get(request.args.get('tb', type=int))
- frame = self.frames.get(request.args.get('frm', type=int))
- if cmd == 'resource' and arg:
+ if request.args.get("__debugger__") == "yes":
+ cmd = request.args.get("cmd")
+ arg = request.args.get("f")
+ secret = request.args.get("s")
+ traceback = self.tracebacks.get(request.args.get("tb", type=int))
+ frame = self.frames.get(request.args.get("frm", type=int))
+ if cmd == "resource" and arg:
response = self.get_resource(request, arg)
- elif cmd == 'paste' and traceback is not None and \
- secret == self.secret:
+ elif cmd == "paste" and traceback is not None and secret == self.secret:
response = self.paste_traceback(request, traceback)
- elif cmd == 'pinauth' and secret == self.secret:
+ elif cmd == "pinauth" and secret == self.secret:
response = self.pin_auth(request)
- elif cmd == 'printpin' and secret == self.secret:
+ elif cmd == "printpin" and secret == self.secret:
response = self.log_pin_request()
- elif self.evalex and cmd is not None and frame is not None \
- and self.secret == secret and \
- self.check_pin_trust(environ):
+ elif (
+ self.evalex
+ and cmd is not None
+ and frame is not None
+ and self.secret == secret
+ and self.check_pin_trust(environ)
+ ):
response = self.execute_command(request, cmd, frame)
- elif self.evalex and self.console_path is not None and \
- request.path == self.console_path:
+ elif (
+ self.evalex
+ and self.console_path is not None
+ and request.path == self.console_path
+ ):
response = self.display_console(request)
return response(environ, start_response)
diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py
index 2915aea1..adbd170b 100644
--- a/src/werkzeug/debug/console.py
+++ b/src/werkzeug/debug/console.py
@@ -8,20 +8,21 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
import code
+import sys
from types import CodeType
-from werkzeug.utils import escape
-from werkzeug.local import Local
-from werkzeug.debug.repr import debug_repr, dump, helper
+from ..local import Local
+from ..utils import escape
+from .repr import debug_repr
+from .repr import dump
+from .repr import helper
_local = Local()
class HTMLStringO(object):
-
"""A StringO version that HTML escapes on write."""
def __init__(self):
@@ -41,46 +42,46 @@ class HTMLStringO(object):
def readline(self):
if len(self._buffer) == 0:
- return ''
+ return ""
ret = self._buffer[0]
del self._buffer[0]
return ret
def reset(self):
- val = ''.join(self._buffer)
+ val = "".join(self._buffer)
del self._buffer[:]
return val
def _write(self, x):
if isinstance(x, bytes):
- x = x.decode('utf-8', 'replace')
+ x = x.decode("utf-8", "replace")
self._buffer.append(x)
def write(self, x):
self._write(escape(x))
def writelines(self, x):
- self._write(escape(''.join(x)))
+ self._write(escape("".join(x)))
class ThreadedStream(object):
-
"""Thread-local wrapper for sys.stdout for the interactive console."""
+ @staticmethod
def push():
if not isinstance(sys.stdout, ThreadedStream):
sys.stdout = ThreadedStream()
_local.stream = HTMLStringO()
- push = staticmethod(push)
+ @staticmethod
def fetch():
try:
stream = _local.stream
except AttributeError:
- return ''
+ return ""
return stream.reset()
- fetch = staticmethod(fetch)
+ @staticmethod
def displayhook(obj):
try:
stream = _local.stream
@@ -89,18 +90,17 @@ class ThreadedStream(object):
# stream._write bypasses escaping as debug_repr is
# already generating HTML for us.
if obj is not None:
- _local._current_ipy.locals['_'] = obj
+ _local._current_ipy.locals["_"] = obj
stream._write(debug_repr(obj))
- displayhook = staticmethod(displayhook)
def __setattr__(self, name, value):
- raise AttributeError('read only attribute %s' % name)
+ raise AttributeError("read only attribute %s" % name)
def __dir__(self):
return dir(sys.__stdout__)
def __getattribute__(self, name):
- if name == '__members__':
+ if name == "__members__":
return dir(sys.__stdout__)
try:
stream = _local.stream
@@ -118,7 +118,6 @@ sys.displayhook = ThreadedStream.displayhook
class _ConsoleLoader(object):
-
def __init__(self):
self._storage = {}
@@ -143,29 +142,30 @@ def _wrap_compiler(console):
code = compile(source, filename, symbol)
console.loader.register(code, source)
return code
+
console.compile = func
class _InteractiveConsole(code.InteractiveInterpreter):
-
def __init__(self, globals, locals):
code.InteractiveInterpreter.__init__(self, locals)
self.globals = dict(globals)
- self.globals['dump'] = dump
- self.globals['help'] = helper
- self.globals['__loader__'] = self.loader = _ConsoleLoader()
+ self.globals["dump"] = dump
+ self.globals["help"] = helper
+ self.globals["__loader__"] = self.loader = _ConsoleLoader()
self.more = False
self.buffer = []
_wrap_compiler(self)
def runsource(self, source):
- source = source.rstrip() + '\n'
+ source = source.rstrip() + "\n"
ThreadedStream.push()
- prompt = '... ' if self.more else '>>> '
+ prompt = "... " if self.more else ">>> "
try:
- source_to_eval = ''.join(self.buffer + [source])
- if code.InteractiveInterpreter.runsource(self,
- source_to_eval, '<debugger>', 'single'):
+ source_to_eval = "".join(self.buffer + [source])
+ if code.InteractiveInterpreter.runsource(
+ self, source_to_eval, "<debugger>", "single"
+ ):
self.more = True
self.buffer.append(source)
else:
@@ -182,12 +182,14 @@ class _InteractiveConsole(code.InteractiveInterpreter):
self.showtraceback()
def showtraceback(self):
- from werkzeug.debug.tbtools import get_current_traceback
+ from .tbtools import get_current_traceback
+
tb = get_current_traceback(skip=1)
sys.stdout._write(tb.render_summary())
def showsyntaxerror(self, filename=None):
- from werkzeug.debug.tbtools import get_current_traceback
+ from .tbtools import get_current_traceback
+
tb = get_current_traceback(skip=4)
sys.stdout._write(tb.render_summary())
@@ -196,7 +198,6 @@ class _InteractiveConsole(code.InteractiveInterpreter):
class Console(object):
-
"""An interactive console."""
def __init__(self, globals=None, locals=None):
diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py
index e82d87c4..d7a7285c 100644
--- a/src/werkzeug/debug/repr.py
+++ b/src/werkzeug/debug/repr.py
@@ -13,37 +13,38 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
-import re
import codecs
+import re
+import sys
+from collections import deque
from traceback import format_exception_only
-try:
- from collections import deque
-except ImportError: # pragma: no cover
- deque = None
-from werkzeug.utils import escape
-from werkzeug._compat import iteritems, PY2, text_type, integer_types, \
- string_types
+
+from .._compat import integer_types
+from .._compat import iteritems
+from .._compat import PY2
+from .._compat import string_types
+from .._compat import text_type
+from ..utils import escape
missing = object()
-_paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}')
+_paragraph_re = re.compile(r"(?:\r\n|\r|\n){2,}")
RegexType = type(_paragraph_re)
-HELP_HTML = '''\
+HELP_HTML = """\
<div class=box>
<h3>%(title)s</h3>
<pre class=help>%(text)s</pre>
</div>\
-'''
-OBJECT_DUMP_HTML = '''\
+"""
+OBJECT_DUMP_HTML = """\
<div class=box>
<h3>%(title)s</h3>
%(repr)s
<table>%(items)s</table>
</div>\
-'''
+"""
def debug_repr(obj):
@@ -64,31 +65,31 @@ def dump(obj=missing):
class _Helper(object):
-
"""Displays an HTML version of the normal help, for the interactive
debugger only because it requires a patched sys.stdout.
"""
def __repr__(self):
- return 'Type help(object) for help about object.'
+ return "Type help(object) for help about object."
def __call__(self, topic=None):
if topic is None:
- sys.stdout._write('<span class=help>%s</span>' % repr(self))
+ sys.stdout._write("<span class=help>%s</span>" % repr(self))
return
import pydoc
+
pydoc.help(topic)
rv = sys.stdout.reset()
if isinstance(rv, bytes):
- rv = rv.decode('utf-8', 'ignore')
+ rv = rv.decode("utf-8", "ignore")
paragraphs = _paragraph_re.split(rv)
if len(paragraphs) > 1:
title = paragraphs[0]
- text = '\n\n'.join(paragraphs[1:])
+ text = "\n\n".join(paragraphs[1:])
else: # pragma: no cover
- title = 'Help'
+ title = "Help"
text = paragraphs[0]
- sys.stdout._write(HELP_HTML % {'title': title, 'text': text})
+ sys.stdout._write(HELP_HTML % {"title": title, "text": text})
helper = _Helper()
@@ -101,55 +102,55 @@ def _add_subclass_info(inner, obj, base):
return inner
elif type(obj) is base:
return inner
- module = ''
- if obj.__class__.__module__ not in ('__builtin__', 'exceptions'):
+ module = ""
+ if obj.__class__.__module__ not in ("__builtin__", "exceptions"):
module = '<span class="module">%s.</span>' % obj.__class__.__module__
- return '%s%s(%s)' % (module, obj.__class__.__name__, inner)
+ return "%s%s(%s)" % (module, obj.__class__.__name__, inner)
class DebugReprGenerator(object):
-
def __init__(self):
self._stack = []
- def _sequence_repr_maker(left, right, base=object(), limit=8):
+ def _sequence_repr_maker(left, right, base=object(), limit=8): # noqa: B008, B902
def proxy(self, obj, recursive):
if recursive:
- return _add_subclass_info(left + '...' + right, obj, base)
+ return _add_subclass_info(left + "..." + right, obj, base)
buf = [left]
have_extended_section = False
for idx, item in enumerate(obj):
if idx:
- buf.append(', ')
+ buf.append(", ")
if idx == limit:
buf.append('<span class="extended">')
have_extended_section = True
buf.append(self.repr(item))
if have_extended_section:
- buf.append('</span>')
+ buf.append("</span>")
buf.append(right)
- return _add_subclass_info(u''.join(buf), obj, base)
+ return _add_subclass_info(u"".join(buf), obj, base)
+
return proxy
- list_repr = _sequence_repr_maker('[', ']', list)
- tuple_repr = _sequence_repr_maker('(', ')', tuple)
- set_repr = _sequence_repr_maker('set([', '])', set)
- frozenset_repr = _sequence_repr_maker('frozenset([', '])', frozenset)
- if deque is not None:
- deque_repr = _sequence_repr_maker('<span class="module">collections.'
- '</span>deque([', '])', deque)
+ list_repr = _sequence_repr_maker("[", "]", list)
+ tuple_repr = _sequence_repr_maker("(", ")", tuple)
+ set_repr = _sequence_repr_maker("set([", "])", set)
+ frozenset_repr = _sequence_repr_maker("frozenset([", "])", frozenset)
+ deque_repr = _sequence_repr_maker(
+ '<span class="module">collections.' "</span>deque([", "])", deque
+ )
del _sequence_repr_maker
def regex_repr(self, obj):
pattern = repr(obj.pattern)
if PY2:
- pattern = pattern.decode('string-escape', 'ignore')
+ pattern = pattern.decode("string-escape", "ignore")
else:
- pattern = codecs.decode(pattern, 'unicode-escape', 'ignore')
- if pattern[:1] == 'u':
- pattern = 'ur' + pattern[1:]
+ pattern = codecs.decode(pattern, "unicode-escape", "ignore")
+ if pattern[:1] == "u":
+ pattern = "ur" + pattern[1:]
else:
- pattern = 'r' + pattern
+ pattern = "r" + pattern
return u're.compile(<span class="string regex">%s</span>)' % pattern
def string_repr(self, obj, limit=70):
@@ -158,14 +159,18 @@ class DebugReprGenerator(object):
# shorten the repr when the hidden part would be at least 3 chars
if len(r) - limit > 2:
- buf.extend((
- escape(r[:limit]),
- '<span class="extended">', escape(r[limit:]), '</span>',
- ))
+ buf.extend(
+ (
+ escape(r[:limit]),
+ '<span class="extended">',
+ escape(r[limit:]),
+ "</span>",
+ )
+ )
else:
buf.append(escape(r))
- buf.append('</span>')
+ buf.append("</span>")
out = u"".join(buf)
# if the repr looks like a standard string, add subclass info if needed
@@ -177,27 +182,29 @@ class DebugReprGenerator(object):
def dict_repr(self, d, recursive, limit=5):
if recursive:
- return _add_subclass_info(u'{...}', d, dict)
- buf = ['{']
+ return _add_subclass_info(u"{...}", d, dict)
+ buf = ["{"]
have_extended_section = False
for idx, (key, value) in enumerate(iteritems(d)):
if idx:
- buf.append(', ')
+ buf.append(", ")
if idx == limit - 1:
buf.append('<span class="extended">')
have_extended_section = True
- buf.append('<span class="pair"><span class="key">%s</span>: '
- '<span class="value">%s</span></span>' %
- (self.repr(key), self.repr(value)))
+ buf.append(
+ '<span class="pair"><span class="key">%s</span>: '
+ '<span class="value">%s</span></span>'
+ % (self.repr(key), self.repr(value))
+ )
if have_extended_section:
- buf.append('</span>')
- buf.append('}')
- return _add_subclass_info(u''.join(buf), d, dict)
+ buf.append("</span>")
+ buf.append("}")
+ return _add_subclass_info(u"".join(buf), d, dict)
def object_repr(self, obj):
r = repr(obj)
if PY2:
- r = r.decode('utf-8', 'replace')
+ r = r.decode("utf-8", "replace")
return u'<span class="object">%s</span>' % escape(r)
def dispatch_repr(self, obj, recursive):
@@ -225,13 +232,14 @@ class DebugReprGenerator(object):
def fallback_repr(self):
try:
- info = ''.join(format_exception_only(*sys.exc_info()[:2]))
+ info = "".join(format_exception_only(*sys.exc_info()[:2]))
except Exception: # pragma: no cover
- info = '?'
+ info = "?"
if PY2:
- info = info.decode('utf-8', 'ignore')
- return u'<span class="brokenrepr">&lt;broken repr (%s)&gt;' \
- u'</span>' % escape(info.strip())
+ info = info.decode("utf-8", "ignore")
+ return u'<span class="brokenrepr">&lt;broken repr (%s)&gt;' u"</span>" % escape(
+ info.strip()
+ )
def repr(self, obj):
recursive = False
@@ -251,7 +259,7 @@ class DebugReprGenerator(object):
def dump_object(self, obj):
repr = items = None
if isinstance(obj, dict):
- title = 'Contents of'
+ title = "Contents of"
items = []
for key, value in iteritems(obj):
if not isinstance(key, string_types):
@@ -266,23 +274,24 @@ class DebugReprGenerator(object):
items.append((key, self.repr(getattr(obj, key))))
except Exception:
pass
- title = 'Details for'
- title += ' ' + object.__repr__(obj)[1:-1]
+ title = "Details for"
+ title += " " + object.__repr__(obj)[1:-1]
return self.render_object_dump(items, title, repr)
def dump_locals(self, d):
items = [(key, self.repr(value)) for key, value in d.items()]
- return self.render_object_dump(items, 'Local variables in frame')
+ return self.render_object_dump(items, "Local variables in frame")
def render_object_dump(self, items, title, repr=None):
html_items = []
for key, value in items:
- html_items.append('<tr><th>%s<td><pre class=repr>%s</pre>' %
- (escape(key), value))
+ html_items.append(
+ "<tr><th>%s<td><pre class=repr>%s</pre>" % (escape(key), value)
+ )
if not html_items:
- html_items.append('<tr><td><em>Nothing</em>')
+ html_items.append("<tr><td><em>Nothing</em>")
return OBJECT_DUMP_HTML % {
- 'title': escape(title),
- 'repr': '<pre class=repr>%s</pre>' % repr if repr else '',
- 'items': '\n'.join(html_items)
+ "title": escape(title),
+ "repr": "<pre class=repr>%s</pre>" % repr if repr else "",
+ "items": "\n".join(html_items),
}
diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py
index 9422052f..f6af4e30 100644
--- a/src/werkzeug/debug/tbtools.py
+++ b/src/werkzeug/debug/tbtools.py
@@ -8,30 +8,33 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import re
-
+import codecs
+import inspect
+import json
import os
+import re
import sys
-import json
-import inspect
import sysconfig
import traceback
-import codecs
from tokenize import TokenError
-from werkzeug.utils import cached_property, escape
-from werkzeug.debug.console import Console
-from werkzeug._compat import (
- range_type, PY2, text_type, string_types,
- to_native, to_unicode, reraise,
-)
-from werkzeug.filesystem import get_filesystem_encoding
+from .._compat import PY2
+from .._compat import range_type
+from .._compat import reraise
+from .._compat import string_types
+from .._compat import text_type
+from .._compat import to_native
+from .._compat import to_unicode
+from ..filesystem import get_filesystem_encoding
+from ..utils import cached_property
+from ..utils import escape
+from .console import Console
-_coding_re = re.compile(br'coding[:=]\s*([-\w.]+)')
-_line_re = re.compile(br'^(.*?)$', re.MULTILINE)
-_funcdef_re = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)')
-UTF8_COOKIE = b'\xef\xbb\xbf'
+_coding_re = re.compile(br"coding[:=]\s*([-\w.]+)")
+_line_re = re.compile(br"^(.*?)$", re.MULTILINE)
+_funcdef_re = re.compile(r"^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)")
+UTF8_COOKIE = b"\xef\xbb\xbf"
system_exceptions = (SystemExit, KeyboardInterrupt)
try:
@@ -40,7 +43,7 @@ except NameError:
pass
-HEADER = u'''\
+HEADER = u"""\
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
"http://www.w3.org/TR/html4/loose.dtd">
<html>
@@ -65,8 +68,8 @@ HEADER = u'''\
</head>
<body style="background-color: #fff">
<div class="debugger">
-'''
-FOOTER = u'''\
+"""
+FOOTER = u"""\
<div class="footer">
Brought to you by <strong class="arthur">DON'T PANIC</strong>, your
friendly Werkzeug powered traceback interpreter.
@@ -89,9 +92,11 @@ FOOTER = u'''\
</div>
</body>
</html>
-'''
+"""
-PAGE_HTML = HEADER + u'''\
+PAGE_HTML = (
+ HEADER
+ + u"""\
<h1>%(exception_type)s</h1>
<div class="detail">
<p class="errormsg">%(exception)s</p>
@@ -117,61 +122,69 @@ PAGE_HTML = HEADER + u'''\
execution (if the evalex feature is enabled), automatic pasting of the
exceptions and much more.</span>
</div>
-''' + FOOTER + '''
+"""
+ + FOOTER
+ + """
<!--
%(plaintext_cs)s
-->
-'''
+"""
+)
-CONSOLE_HTML = HEADER + u'''\
+CONSOLE_HTML = (
+ HEADER
+ + u"""\
<h1>Interactive Console</h1>
<div class="explanation">
In this console you can execute Python expressions in the context of the
application. The initial namespace was created by the debugger automatically.
</div>
<div class="console"><div class="inner">The Console requires JavaScript.</div></div>
-''' + FOOTER
+"""
+ + FOOTER
+)
-SUMMARY_HTML = u'''\
+SUMMARY_HTML = u"""\
<div class="%(classes)s">
%(title)s
<ul>%(frames)s</ul>
%(description)s
</div>
-'''
+"""
-FRAME_HTML = u'''\
+FRAME_HTML = u"""\
<div class="frame" id="frame-%(id)d">
<h4>File <cite class="filename">"%(filename)s"</cite>,
line <em class="line">%(lineno)s</em>,
in <code class="function">%(function_name)s</code></h4>
<div class="source %(library)s">%(lines)s</div>
</div>
-'''
+"""
-SOURCE_LINE_HTML = u'''\
+SOURCE_LINE_HTML = u"""\
<tr class="%(classes)s">
<td class=lineno>%(lineno)s</td>
<td>%(code)s</td>
</tr>
-'''
+"""
def render_console_html(secret, evalex_trusted=True):
return CONSOLE_HTML % {
- 'evalex': 'true',
- 'evalex_trusted': 'true' if evalex_trusted else 'false',
- 'console': 'true',
- 'title': 'Console',
- 'secret': secret,
- 'traceback_id': -1
+ "evalex": "true",
+ "evalex_trusted": "true" if evalex_trusted else "false",
+ "console": "true",
+ "title": "Console",
+ "secret": secret,
+ "traceback_id": -1,
}
-def get_current_traceback(ignore_system_exceptions=False,
- show_hidden_frames=False, skip=0):
+def get_current_traceback(
+ ignore_system_exceptions=False, show_hidden_frames=False, skip=0
+):
"""Get the current exception info as `Traceback` object. Per default
calling this method will reraise system exceptions such as generator exit,
system exit or others. This behavior can be disabled by passing `False`
@@ -192,7 +205,8 @@ def get_current_traceback(ignore_system_exceptions=False,
class Line(object):
"""Helper for the source renderer."""
- __slots__ = ('lineno', 'code', 'in_frame', 'current')
+
+ __slots__ = ("lineno", "code", "in_frame", "current")
def __init__(self, lineno, code):
self.lineno = lineno
@@ -202,18 +216,18 @@ class Line(object):
@property
def classes(self):
- rv = ['line']
+ rv = ["line"]
if self.in_frame:
- rv.append('in-frame')
+ rv.append("in-frame")
if self.current:
- rv.append('current')
+ rv.append("current")
return rv
def render(self):
return SOURCE_LINE_HTML % {
- 'classes': u' '.join(self.classes),
- 'lineno': self.lineno,
- 'code': escape(self.code)
+ "classes": u" ".join(self.classes),
+ "lineno": self.lineno,
+ "code": escape(self.code),
}
@@ -227,7 +241,7 @@ class Traceback(object):
exception_type = exc_type.__name__
if exc_type.__module__ not in {"builtins", "__builtin__", "exceptions"}:
- exception_type = exc_type.__module__ + '.' + exception_type
+ exception_type = exc_type.__module__ + "." + exception_type
self.exception_type = exception_type
self.groups = []
@@ -264,38 +278,33 @@ class Traceback(object):
"""Log the ASCII traceback into a file object."""
if logfile is None:
logfile = sys.stderr
- tb = self.plaintext.rstrip() + u'\n'
+ tb = self.plaintext.rstrip() + u"\n"
logfile.write(to_native(tb, "utf-8", "replace"))
def paste(self):
"""Create a paste and return the paste id."""
- data = json.dumps({
- 'description': 'Werkzeug Internal Server Error',
- 'public': False,
- 'files': {
- 'traceback.txt': {
- 'content': self.plaintext
- }
+ data = json.dumps(
+ {
+ "description": "Werkzeug Internal Server Error",
+ "public": False,
+ "files": {"traceback.txt": {"content": self.plaintext}},
}
- }).encode('utf-8')
+ ).encode("utf-8")
try:
from urllib2 import urlopen
except ImportError:
from urllib.request import urlopen
- rv = urlopen('https://api.github.com/gists', data=data)
- resp = json.loads(rv.read().decode('utf-8'))
+ rv = urlopen("https://api.github.com/gists", data=data)
+ resp = json.loads(rv.read().decode("utf-8"))
rv.close()
- return {
- 'url': resp['html_url'],
- 'id': resp['id']
- }
+ return {"url": resp["html_url"], "id": resp["id"]}
def render_summary(self, include_title=True):
"""Render the traceback for the interactive console."""
- title = ''
- classes = ['traceback']
+ title = ""
+ classes = ["traceback"]
if not self.frames:
- classes.append('noframe-traceback')
+ classes.append("noframe-traceback")
frames = []
else:
library_frames = sum(frame.is_library for frame in self.frames)
@@ -304,38 +313,37 @@ class Traceback(object):
if include_title:
if self.is_syntax_error:
- title = u'Syntax Error'
+ title = u"Syntax Error"
else:
- title = u'Traceback <em>(most recent call last)</em>:'
+ title = u"Traceback <em>(most recent call last)</em>:"
if self.is_syntax_error:
- description_wrapper = u'<pre class=syntaxerror>%s</pre>'
+ description_wrapper = u"<pre class=syntaxerror>%s</pre>"
else:
- description_wrapper = u'<blockquote>%s</blockquote>'
+ description_wrapper = u"<blockquote>%s</blockquote>"
return SUMMARY_HTML % {
- 'classes': u' '.join(classes),
- 'title': u'<h3>%s</h3>' % title if title else u'',
- 'frames': u'\n'.join(frames),
- 'description': description_wrapper % escape(self.exception)
+ "classes": u" ".join(classes),
+ "title": u"<h3>%s</h3>" % title if title else u"",
+ "frames": u"\n".join(frames),
+ "description": description_wrapper % escape(self.exception),
}
- def render_full(self, evalex=False, secret=None,
- evalex_trusted=True):
+ def render_full(self, evalex=False, secret=None, evalex_trusted=True):
"""Render the Full HTML page with the traceback info."""
exc = escape(self.exception)
return PAGE_HTML % {
- 'evalex': 'true' if evalex else 'false',
- 'evalex_trusted': 'true' if evalex_trusted else 'false',
- 'console': 'false',
- 'title': exc,
- 'exception': exc,
- 'exception_type': escape(self.exception_type),
- 'summary': self.render_summary(include_title=False),
- 'plaintext': escape(self.plaintext),
- 'plaintext_cs': re.sub('-{2,}', '-', self.plaintext),
- 'traceback_id': self.id,
- 'secret': secret
+ "evalex": "true" if evalex else "false",
+ "evalex_trusted": "true" if evalex_trusted else "false",
+ "console": "false",
+ "title": exc,
+ "exception": exc,
+ "exception_type": escape(self.exception_type),
+ "summary": self.render_summary(include_title=False),
+ "plaintext": escape(self.plaintext),
+ "plaintext_cs": re.sub("-{2,}", "-", self.plaintext),
+ "traceback_id": self.id,
+ "secret": secret,
}
@cached_property
@@ -418,10 +426,13 @@ class Group(object):
if self.info is not None:
out.append(u'<li><div class="exc-divider">%s:</div>' % self.info)
for frame in self.frames:
- out.append(u"<li%s>%s" % (
- u' title="%s"' % escape(frame.info) if frame.info else u"",
- frame.render(mark_lib=mark_lib)
- ))
+ out.append(
+ u"<li%s>%s"
+ % (
+ u' title="%s"' % escape(frame.info) if frame.info else u"",
+ frame.render(mark_lib=mark_lib),
+ )
+ )
return u"\n".join(out)
def render_text(self):
@@ -436,7 +447,6 @@ class Group(object):
class Frame(object):
-
"""A single frame in a traceback."""
def __init__(self, exc_type, exc_value, tb):
@@ -446,19 +456,19 @@ class Frame(object):
self.globals = tb.tb_frame.f_globals
fn = inspect.getsourcefile(tb) or inspect.getfile(tb)
- if fn[-4:] in ('.pyo', '.pyc'):
+ if fn[-4:] in (".pyo", ".pyc"):
fn = fn[:-1]
# if it's a file on the file system resolve the real filename.
if os.path.isfile(fn):
fn = os.path.realpath(fn)
self.filename = to_unicode(fn, get_filesystem_encoding())
- self.module = self.globals.get('__name__')
- self.loader = self.globals.get('__loader__')
+ self.module = self.globals.get("__name__")
+ self.loader = self.globals.get("__loader__")
self.code = tb.tb_frame.f_code
# support for paste's traceback extensions
- self.hide = self.locals.get('__traceback_hide__', False)
- info = self.locals.get('__traceback_info__')
+ self.hide = self.locals.get("__traceback_hide__", False)
+ info = self.locals.get("__traceback_info__")
if info is not None:
info = to_unicode(info, "utf-8", "replace")
self.info = info
@@ -466,11 +476,11 @@ class Frame(object):
def render(self, mark_lib=True):
"""Render a single frame in a traceback."""
return FRAME_HTML % {
- 'id': self.id,
- 'filename': escape(self.filename),
- 'lineno': self.lineno,
- 'function_name': escape(self.function_name),
- 'lines': self.render_line_context(),
+ "id": self.id,
+ "filename": escape(self.filename),
+ "lineno": self.lineno,
+ "function_name": escape(self.function_name),
+ "lines": self.render_line_context(),
"library": "library" if mark_lib and self.is_library else "",
}
@@ -485,7 +495,7 @@ class Frame(object):
self.filename,
self.lineno,
self.function_name,
- self.current_line.strip()
+ self.current_line.strip(),
)
def render_line_context(self):
@@ -497,34 +507,34 @@ class Frame(object):
stripped_line = line.strip()
prefix = len(line) - len(stripped_line)
rv.append(
- '<pre class="line %s"><span class="ws">%s</span>%s</pre>' % (
- cls, ' ' * prefix, escape(stripped_line) or ' '))
+ '<pre class="line %s"><span class="ws">%s</span>%s</pre>'
+ % (cls, " " * prefix, escape(stripped_line) or " ")
+ )
for line in before:
- render_line(line, 'before')
- render_line(current, 'current')
+ render_line(line, "before")
+ render_line(current, "current")
for line in after:
- render_line(line, 'after')
+ render_line(line, "after")
- return '\n'.join(rv)
+ return "\n".join(rv)
def get_annotated_lines(self):
"""Helper function that returns lines with extra information."""
lines = [Line(idx + 1, x) for idx, x in enumerate(self.sourcelines)]
# find function definition and mark lines
- if hasattr(self.code, 'co_firstlineno'):
+ if hasattr(self.code, "co_firstlineno"):
lineno = self.code.co_firstlineno - 1
while lineno > 0:
if _funcdef_re.match(lines[lineno].code):
break
lineno -= 1
try:
- offset = len(inspect.getblock([x.code + '\n' for x
- in lines[lineno:]]))
+ offset = len(inspect.getblock([x.code + "\n" for x in lines[lineno:]]))
except TokenError:
offset = 0
- for line in lines[lineno:lineno + offset]:
+ for line in lines[lineno : lineno + offset]:
line.in_frame = True
# mark current line
@@ -535,12 +545,12 @@ class Frame(object):
return lines
- def eval(self, code, mode='single'):
+ def eval(self, code, mode="single"):
"""Evaluate code in the context of the frame."""
if isinstance(code, string_types):
if PY2 and isinstance(code, text_type): # noqa
- code = UTF8_COOKIE + code.encode('utf-8')
- code = compile(code, '<interactive>', mode)
+ code = UTF8_COOKIE + code.encode("utf-8")
+ code = compile(code, "<interactive>", mode)
return eval(code, self.globals, self.locals)
@cached_property
@@ -550,9 +560,9 @@ class Frame(object):
source = None
if self.loader is not None:
try:
- if hasattr(self.loader, 'get_source'):
+ if hasattr(self.loader, "get_source"):
source = self.loader.get_source(self.module)
- elif hasattr(self.loader, 'get_source_by_code'):
+ elif hasattr(self.loader, "get_source_by_code"):
source = self.loader.get_source_by_code(self.code)
except Exception:
# we munch the exception so that we don't cause troubles
@@ -561,8 +571,7 @@ class Frame(object):
if source is None:
try:
- f = open(to_native(self.filename, get_filesystem_encoding()),
- mode='rb')
+ f = open(to_native(self.filename, get_filesystem_encoding()), mode="rb")
except IOError:
return []
try:
@@ -576,7 +585,7 @@ class Frame(object):
# yes. it should be ascii, but we don't want to reject too many
# characters in the debugger if something breaks
- charset = 'utf-8'
+ charset = "utf-8"
if source.startswith(UTF8_COOKIE):
source = source[3:]
else:
@@ -593,25 +602,21 @@ class Frame(object):
try:
codecs.lookup(charset)
except LookupError:
- charset = 'utf-8'
+ charset = "utf-8"
- return source.decode(charset, 'replace').splitlines()
+ return source.decode(charset, "replace").splitlines()
def get_context_lines(self, context=5):
- before = self.sourcelines[self.lineno - context - 1:self.lineno - 1]
- past = self.sourcelines[self.lineno:self.lineno + context]
- return (
- before,
- self.current_line,
- past,
- )
+ before = self.sourcelines[self.lineno - context - 1 : self.lineno - 1]
+ past = self.sourcelines[self.lineno : self.lineno + context]
+ return (before, self.current_line, past)
@property
def current_line(self):
try:
return self.sourcelines[self.lineno - 1]
except IndexError:
- return u''
+ return u""
@cached_property
def console(self):
diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py
index e2c2cb3e..0e6d1ce3 100644
--- a/src/werkzeug/exceptions.py
+++ b/src/werkzeug/exceptions.py
@@ -59,23 +59,23 @@
"""
import sys
+import werkzeug
+
# Because of bootstrapping reasons we need to manually patch ourselves
# onto our parent module.
-import werkzeug
werkzeug.exceptions = sys.modules[__name__]
-from werkzeug._internal import _get_environ
-from werkzeug._compat import iteritems, integer_types, text_type, \
- implements_to_string
-
-from werkzeug.wrappers import Response
+from ._compat import implements_to_string
+from ._compat import integer_types
+from ._compat import iteritems
+from ._compat import text_type
+from ._internal import _get_environ
+from .wrappers import Response
@implements_to_string
class HTTPException(Exception):
-
- """
- Baseclass for all HTTP exceptions. This exception can be called as WSGI
+ """Baseclass for all HTTP exceptions. This exception can be called as WSGI
application to render a default error page or you can catch the subclasses
of it independently and render nicer error messages.
"""
@@ -102,6 +102,7 @@ class HTTPException(Exception):
.. versionchanged:: 0.15
The description includes the wrapped exception message.
"""
+
class newcls(cls, exception):
def __init__(self, arg=None, *args, **kwargs):
super(cls, self).__init__(*args, **kwargs)
@@ -116,41 +117,43 @@ class HTTPException(Exception):
if self.args:
out += "<p><pre><code>{}: {}</code></pre></p>".format(
- exception.__name__,
- escape(exception.__str__(self))
+ exception.__name__, escape(exception.__str__(self))
)
return out
- newcls.__module__ = sys._getframe(1).f_globals.get('__name__')
+ newcls.__module__ = sys._getframe(1).f_globals.get("__name__")
newcls.__name__ = name or cls.__name__ + exception.__name__
return newcls
@property
def name(self):
"""The status name."""
- return HTTP_STATUS_CODES.get(self.code, 'Unknown Error')
+ return HTTP_STATUS_CODES.get(self.code, "Unknown Error")
def get_description(self, environ=None):
"""Get the description."""
- return u'<p>%s</p>' % escape(self.description)
+ return u"<p>%s</p>" % escape(self.description)
def get_body(self, environ=None):
"""Get the HTML body."""
- return text_type((
- u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n'
- u'<title>%(code)s %(name)s</title>\n'
- u'<h1>%(name)s</h1>\n'
- u'%(description)s\n'
- ) % {
- 'code': self.code,
- 'name': escape(self.name),
- 'description': self.get_description(environ)
- })
+ return text_type(
+ (
+ u'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n'
+ u"<title>%(code)s %(name)s</title>\n"
+ u"<h1>%(name)s</h1>\n"
+ u"%(description)s\n"
+ )
+ % {
+ "code": self.code,
+ "name": escape(self.name),
+ "description": self.get_description(environ),
+ }
+ )
def get_headers(self, environ=None):
"""Get a list of headers."""
- return [('Content-Type', 'text/html')]
+ return [("Content-Type", "text/html")]
def get_response(self, environ=None):
"""Get a response object. If one was passed to the exception
@@ -179,30 +182,29 @@ class HTTPException(Exception):
return response(environ, start_response)
def __str__(self):
- code = self.code if self.code is not None else '???'
- return '%s %s: %s' % (code, self.name, self.description)
+ code = self.code if self.code is not None else "???"
+ return "%s %s: %s" % (code, self.name, self.description)
def __repr__(self):
- code = self.code if self.code is not None else '???'
+ code = self.code if self.code is not None else "???"
return "<%s '%s: %s'>" % (self.__class__.__name__, code, self.name)
class BadRequest(HTTPException):
-
"""*400* `Bad Request`
Raise if the browser sends something to the application the application
or server cannot handle.
"""
+
code = 400
description = (
- 'The browser (or proxy) sent a request that this server could '
- 'not understand.'
+ "The browser (or proxy) sent a request that this server could "
+ "not understand."
)
class ClientDisconnected(BadRequest):
-
"""Internal exception that is raised if Werkzeug detects a disconnected
client. Since the client is already gone at that point attempting to
send the error message to the client might not work and might ultimately
@@ -218,7 +220,6 @@ class ClientDisconnected(BadRequest):
class SecurityError(BadRequest):
-
"""Raised if something triggers a security error. This is otherwise
exactly like a bad request error.
@@ -227,7 +228,6 @@ class SecurityError(BadRequest):
class BadHost(BadRequest):
-
"""Raised if the submitted host is badly formatted.
.. versionadded:: 0.11.2
@@ -235,7 +235,6 @@ class BadHost(BadRequest):
class Unauthorized(HTTPException):
-
"""*401* ``Unauthorized``
Raise if the user is not authorized to access a resource.
@@ -252,12 +251,13 @@ class Unauthorized(HTTPException):
:param description: Override the default message used for the body
of the response.
"""
+
code = 401
description = (
- 'The server could not verify that you are authorized to access'
- ' the URL requested. You either supplied the wrong credentials'
+ "The server could not verify that you are authorized to access"
+ " the URL requested. You either supplied the wrong credentials"
" (e.g. a bad password), or your browser doesn't understand"
- ' how to supply the credentials required.'
+ " how to supply the credentials required."
)
def __init__(self, www_authenticate=None, description=None):
@@ -269,43 +269,41 @@ class Unauthorized(HTTPException):
def get_headers(self, environ=None):
headers = HTTPException.get_headers(self, environ)
if self.www_authenticate:
- headers.append((
- 'WWW-Authenticate',
- ', '.join([str(x) for x in self.www_authenticate])
- ))
+ headers.append(
+ ("WWW-Authenticate", ", ".join([str(x) for x in self.www_authenticate]))
+ )
return headers
class Forbidden(HTTPException):
-
"""*403* `Forbidden`
Raise if the user doesn't have the permission for the requested resource
but was authenticated.
"""
+
code = 403
description = (
- 'You don\'t have the permission to access the requested resource. '
- 'It is either read-protected or not readable by the server.'
+ "You don't have the permission to access the requested"
+ " resource. It is either read-protected or not readable by the"
+ " server."
)
class NotFound(HTTPException):
-
"""*404* `Not Found`
Raise if a resource does not exist and never existed.
"""
+
code = 404
description = (
- 'The requested URL was not found on the server. '
- 'If you entered the URL manually please check your spelling and '
- 'try again.'
+ "The requested URL was not found on the server. If you entered"
+ " the URL manually please check your spelling and try again."
)
class MethodNotAllowed(HTTPException):
-
"""*405* `Method Not Allowed`
Raise if the server used a method the resource does not handle. For
@@ -315,8 +313,9 @@ class MethodNotAllowed(HTTPException):
Strictly speaking the response would be invalid if you don't provide valid
methods in the header which you can do with that list.
"""
+
code = 405
- description = 'The method is not allowed for the requested URL.'
+ description = "The method is not allowed for the requested URL."
def __init__(self, valid_methods=None, description=None):
"""Takes an optional list of valid http methods
@@ -327,42 +326,41 @@ class MethodNotAllowed(HTTPException):
def get_headers(self, environ=None):
headers = HTTPException.get_headers(self, environ)
if self.valid_methods:
- headers.append(('Allow', ', '.join(self.valid_methods)))
+ headers.append(("Allow", ", ".join(self.valid_methods)))
return headers
class NotAcceptable(HTTPException):
-
"""*406* `Not Acceptable`
Raise if the server can't return any content conforming to the
`Accept` headers of the client.
"""
+
code = 406
description = (
- 'The resource identified by the request is only capable of '
- 'generating response entities which have content characteristics '
- 'not acceptable according to the accept headers sent in the '
- 'request.'
+ "The resource identified by the request is only capable of"
+ " generating response entities which have content"
+ " characteristics not acceptable according to the accept"
+ " headers sent in the request."
)
class RequestTimeout(HTTPException):
-
"""*408* `Request Timeout`
Raise to signalize a timeout.
"""
+
code = 408
description = (
- 'The server closed the network connection because the browser '
- 'didn\'t finish the request within the specified time.'
+ "The server closed the network connection because the browser"
+ " didn't finish the request within the specified time."
)
class Conflict(HTTPException):
-
"""*409* `Conflict`
Raise to signal that a request cannot be completed because it conflicts
@@ -370,107 +368,103 @@ class Conflict(HTTPException):
.. versionadded:: 0.7
"""
+
code = 409
description = (
- 'A conflict happened while processing the request. The resource '
- 'might have been modified while the request was being processed.'
+ "A conflict happened while processing the request. The"
+ " resource might have been modified while the request was being"
+ " processed."
)
class Gone(HTTPException):
-
"""*410* `Gone`
Raise if a resource existed previously and went away without new location.
"""
+
code = 410
description = (
- 'The requested URL is no longer available on this server and there '
- 'is no forwarding address. If you followed a link from a foreign '
- 'page, please contact the author of this page.'
+ "The requested URL is no longer available on this server and"
+ " there is no forwarding address. If you followed a link from a"
+ " foreign page, please contact the author of this page."
)
class LengthRequired(HTTPException):
-
"""*411* `Length Required`
Raise if the browser submitted data but no ``Content-Length`` header which
is required for the kind of processing the server does.
"""
+
code = 411
description = (
- 'A request with this method requires a valid <code>Content-'
- 'Length</code> header.'
+ "A request with this method requires a valid <code>Content-"
+ "Length</code> header."
)
class PreconditionFailed(HTTPException):
-
"""*412* `Precondition Failed`
Status code used in combination with ``If-Match``, ``If-None-Match``, or
``If-Unmodified-Since``.
"""
+
code = 412
description = (
- 'The precondition on the request for the URL failed positive '
- 'evaluation.'
+ "The precondition on the request for the URL failed positive evaluation."
)
class RequestEntityTooLarge(HTTPException):
-
"""*413* `Request Entity Too Large`
The status code one should return if the data submitted exceeded a given
limit.
"""
+
code = 413
- description = (
- 'The data value transmitted exceeds the capacity limit.'
- )
+ description = "The data value transmitted exceeds the capacity limit."
class RequestURITooLarge(HTTPException):
-
"""*414* `Request URI Too Large`
Like *413* but for too long URLs.
"""
+
code = 414
description = (
- 'The length of the requested URL exceeds the capacity limit '
- 'for this server. The request cannot be processed.'
+ "The length of the requested URL exceeds the capacity limit for"
+ " this server. The request cannot be processed."
)
class UnsupportedMediaType(HTTPException):
-
"""*415* `Unsupported Media Type`
The status code returned if the server is unable to handle the media type
the client transmitted.
"""
+
code = 415
description = (
- 'The server does not support the media type transmitted in '
- 'the request.'
+ "The server does not support the media type transmitted in the request."
)
class RequestedRangeNotSatisfiable(HTTPException):
-
"""*416* `Requested Range Not Satisfiable`
The client asked for an invalid part of the file.
.. versionadded:: 0.7
"""
+
code = 416
- description = (
- 'The server cannot provide the requested range.'
- )
+ description = "The server cannot provide the requested range."
def __init__(self, length=None, units="bytes", description=None):
"""Takes an optional `Content-Range` header value based on ``length``
@@ -483,27 +477,23 @@ class RequestedRangeNotSatisfiable(HTTPException):
def get_headers(self, environ=None):
headers = HTTPException.get_headers(self, environ)
if self.length is not None:
- headers.append(
- ('Content-Range', '%s */%d' % (self.units, self.length)))
+ headers.append(("Content-Range", "%s */%d" % (self.units, self.length)))
return headers
class ExpectationFailed(HTTPException):
-
"""*417* `Expectation Failed`
The server cannot meet the requirements of the Expect request-header.
.. versionadded:: 0.7
"""
+
code = 417
- description = (
- 'The server could not meet the requirements of the Expect header'
- )
+ description = "The server could not meet the requirements of the Expect header"
class ImATeapot(HTTPException):
-
"""*418* `I'm a teapot`
The server should return this if it is a teapot and someone attempted
@@ -511,54 +501,51 @@ class ImATeapot(HTTPException):
.. versionadded:: 0.7
"""
+
code = 418
- description = (
- 'This server is a teapot, not a coffee machine'
- )
+ description = "This server is a teapot, not a coffee machine"
class UnprocessableEntity(HTTPException):
-
"""*422* `Unprocessable Entity`
Used if the request is well formed, but the instructions are otherwise
incorrect.
"""
+
code = 422
description = (
- 'The request was well-formed but was unable to be followed '
- 'due to semantic errors.'
+ "The request was well-formed but was unable to be followed due"
+ " to semantic errors."
)
class Locked(HTTPException):
-
"""*423* `Locked`
Used if the resource that is being accessed is locked.
"""
+
code = 423
- description = (
- 'The resource that is being accessed is locked.'
- )
+ description = "The resource that is being accessed is locked."
class FailedDependency(HTTPException):
-
"""*424* `Failed Dependency`
Used if the method could not be performed on the resource
because the requested action depended on another action and that action failed.
"""
+
code = 424
description = (
- 'The method could not be performed on the resource because the requested action '
- 'depended on another action and that action failed.'
+ "The method could not be performed on the resource because the"
+ " requested action depended on another action and that action"
+ " failed."
)
class PreconditionRequired(HTTPException):
-
"""*428* `Precondition Required`
The server requires this request to be conditional, typically to prevent
@@ -569,15 +556,15 @@ class PreconditionRequired(HTTPException):
server ensures that each client has at least seen the previous revision of
the resource.
"""
+
code = 428
description = (
- 'This request is required to be conditional; try using "If-Match" '
- 'or "If-Unmodified-Since".'
+ "This request is required to be conditional; try using"
+ ' "If-Match" or "If-Unmodified-Since".'
)
class TooManyRequests(HTTPException):
-
"""*429* `Too Many Requests`
The server is limiting the rate at which this user receives responses, and
@@ -586,129 +573,117 @@ class TooManyRequests(HTTPException):
"Retry-After" header to indicate how long the user should wait before
retrying.
"""
+
code = 429
- description = (
- 'This user has exceeded an allotted request count. Try again later.'
- )
+ description = "This user has exceeded an allotted request count. Try again later."
class RequestHeaderFieldsTooLarge(HTTPException):
-
"""*431* `Request Header Fields Too Large`
The server refuses to process the request because the header fields are too
large. One or more individual fields may be too large, or the set of all
headers is too large.
"""
+
code = 431
- description = (
- 'One or more header fields exceeds the maximum size.'
- )
+ description = "One or more header fields exceeds the maximum size."
class UnavailableForLegalReasons(HTTPException):
-
"""*451* `Unavailable For Legal Reasons`
This status code indicates that the server is denying access to the
resource as a consequence of a legal demand.
"""
+
code = 451
- description = (
- 'Unavailable for legal reasons.'
- )
+ description = "Unavailable for legal reasons."
class InternalServerError(HTTPException):
-
"""*500* `Internal Server Error`
Raise if an internal server error occurred. This is a good fallback if an
unknown error occurred in the dispatcher.
"""
+
code = 500
description = (
- 'The server encountered an internal error and was unable to '
- 'complete your request. Either the server is overloaded or there '
- 'is an error in the application.'
+ "The server encountered an internal error and was unable to"
+ " complete your request. Either the server is overloaded or"
+ " there is an error in the application."
)
class NotImplemented(HTTPException):
-
"""*501* `Not Implemented`
Raise if the application does not support the action requested by the
browser.
"""
+
code = 501
- description = (
- 'The server does not support the action requested by the '
- 'browser.'
- )
+ description = "The server does not support the action requested by the browser."
class BadGateway(HTTPException):
-
"""*502* `Bad Gateway`
If you do proxying in your application you should return this status code
if you received an invalid response from the upstream server it accessed
in attempting to fulfill the request.
"""
+
code = 502
description = (
- 'The proxy server received an invalid response from an upstream '
- 'server.'
+ "The proxy server received an invalid response from an upstream server."
)
class ServiceUnavailable(HTTPException):
-
"""*503* `Service Unavailable`
Status code you should return if a service is temporarily unavailable.
"""
+
code = 503
description = (
- 'The server is temporarily unable to service your request due to '
- 'maintenance downtime or capacity problems. Please try again '
- 'later.'
+ "The server is temporarily unable to service your request due"
+ " to maintenance downtime or capacity problems. Please try"
+ " again later."
)
class GatewayTimeout(HTTPException):
-
"""*504* `Gateway Timeout`
Status code you should return if a connection to an upstream server
times out.
"""
+
code = 504
- description = (
- 'The connection to an upstream server timed out.'
- )
+ description = "The connection to an upstream server timed out."
class HTTPVersionNotSupported(HTTPException):
-
"""*505* `HTTP Version Not Supported`
The server does not support the HTTP protocol version used in the request.
"""
+
code = 505
description = (
- 'The server does not support the HTTP protocol version used in the '
- 'request.'
+ "The server does not support the HTTP protocol version used in the request."
)
default_exceptions = {}
-__all__ = ['HTTPException']
+__all__ = ["HTTPException"]
def _find_exceptions():
- for name, obj in iteritems(globals()):
+ for _name, obj in iteritems(globals()):
try:
is_http_exception = issubclass(obj, HTTPException)
except TypeError:
@@ -720,14 +695,14 @@ def _find_exceptions():
if old_obj is not None and issubclass(obj, old_obj):
continue
default_exceptions[obj.code] = obj
+
+
_find_exceptions()
del _find_exceptions
class Aborter(object):
-
- """
- When passed a dict of code -> exception items it can be used as
+ """When passed a dict of code -> exception items it can be used as
callable that raises exceptions. If the first argument to the
callable is an integer it will be looked up in the mapping, if it's
a WSGI application it will be raised in a proxy exception.
@@ -746,13 +721,12 @@ class Aborter(object):
if not args and not kwargs and not isinstance(code, integer_types):
raise HTTPException(response=code)
if code not in self.mapping:
- raise LookupError('no exception for %r' % code)
+ raise LookupError("no exception for %r" % code)
raise self.mapping[code](*args, **kwargs)
def abort(status, *args, **kwargs):
- '''
- Raises an :py:exc:`HTTPException` for the given status code or WSGI
+ """Raises an :py:exc:`HTTPException` for the given status code or WSGI
application::
abort(404) # 404 Not Found
@@ -766,9 +740,10 @@ def abort(status, *args, **kwargs):
abort(404)
abort(Response('Hello World'))
- '''
+ """
return _aborter(status, *args, **kwargs)
+
_aborter = Aborter()
@@ -777,5 +752,5 @@ _aborter = Aborter()
BadRequestKeyError = BadRequest.wrap(KeyError)
# imported here because of circular dependencies of werkzeug.utils
-from werkzeug.utils import escape
-from werkzeug.http import HTTP_STATUS_CODES
+from .http import HTTP_STATUS_CODES
+from .utils import escape
diff --git a/src/werkzeug/filesystem.py b/src/werkzeug/filesystem.py
index 927f37ba..d016caea 100644
--- a/src/werkzeug/filesystem.py
+++ b/src/werkzeug/filesystem.py
@@ -8,41 +8,39 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
import codecs
import sys
import warnings
# We do not trust traditional unixes.
-has_likely_buggy_unicode_filesystem = \
- sys.platform.startswith('linux') or 'bsd' in sys.platform
+has_likely_buggy_unicode_filesystem = (
+ sys.platform.startswith("linux") or "bsd" in sys.platform
+)
def _is_ascii_encoding(encoding):
- """
- Given an encoding this figures out if the encoding is actually ASCII (which
+ """Given an encoding this figures out if the encoding is actually ASCII (which
is something we don't actually want in most cases). This is necessary
because ASCII comes under many names such as ANSI_X3.4-1968.
"""
if encoding is None:
return False
try:
- return codecs.lookup(encoding).name == 'ascii'
+ return codecs.lookup(encoding).name == "ascii"
except LookupError:
return False
class BrokenFilesystemWarning(RuntimeWarning, UnicodeWarning):
- '''The warning used by Werkzeug to signal a broken filesystem. Will only be
- used once per runtime.'''
+ """The warning used by Werkzeug to signal a broken filesystem. Will only be
+ used once per runtime."""
_warned_about_filesystem_encoding = False
def get_filesystem_encoding():
- """
- Returns the filesystem encoding that should be used. Note that this is
+ """Returns the filesystem encoding that should be used. Note that this is
different from the Python understanding of the filesystem encoding which
might be deeply flawed. Do not use this value against Python's unicode APIs
because it might be different. See :ref:`filesystem-encoding` for the exact
@@ -54,13 +52,13 @@ def get_filesystem_encoding():
"""
global _warned_about_filesystem_encoding
rv = sys.getfilesystemencoding()
- if has_likely_buggy_unicode_filesystem and not rv \
- or _is_ascii_encoding(rv):
+ if has_likely_buggy_unicode_filesystem and not rv or _is_ascii_encoding(rv):
if not _warned_about_filesystem_encoding:
warnings.warn(
- 'Detected a misconfigured UNIX filesystem: Will use UTF-8 as '
- 'filesystem encoding instead of {0!r}'.format(rv),
- BrokenFilesystemWarning)
+ "Detected a misconfigured UNIX filesystem: Will use"
+ " UTF-8 as filesystem encoding instead of {0!r}".format(rv),
+ BrokenFilesystemWarning,
+ )
_warned_about_filesystem_encoding = True
- return 'utf-8'
+ return "utf-8"
return rv
diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py
index 370c0751..e626935d 100644
--- a/src/werkzeug/formparser.py
+++ b/src/werkzeug/formparser.py
@@ -9,8 +9,24 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import re
import codecs
+import re
+from functools import update_wrapper
+from itertools import chain
+from itertools import repeat
+from itertools import tee
+
+from ._compat import BytesIO
+from ._compat import text_type
+from ._compat import to_native
+from .datastructures import FileStorage
+from .datastructures import Headers
+from .datastructures import MultiDict
+from .http import parse_options_header
+from .urls import url_decode_stream
+from .wsgi import get_content_length
+from .wsgi import get_input_stream
+from .wsgi import make_line_iter
# there are some platforms where SpooledTemporaryFile is not available.
# In that case we need to provide a fallback.
@@ -18,45 +34,43 @@ try:
from tempfile import SpooledTemporaryFile
except ImportError:
from tempfile import TemporaryFile
- SpooledTemporaryFile = None
-
-from itertools import chain, repeat, tee
-from functools import update_wrapper
-from werkzeug._compat import to_native, text_type, BytesIO
-from werkzeug.urls import url_decode_stream
-from werkzeug.wsgi import make_line_iter, \
- get_input_stream, get_content_length
-from werkzeug.datastructures import Headers, FileStorage, MultiDict
-from werkzeug.http import parse_options_header
+ SpooledTemporaryFile = None
#: an iterator that yields empty strings
-_empty_string_iter = repeat('')
+_empty_string_iter = repeat("")
#: a regular expression for multipart boundaries
-_multipart_boundary_re = re.compile('^[ -~]{0,200}[!-~]$')
+_multipart_boundary_re = re.compile("^[ -~]{0,200}[!-~]$")
#: supported http encodings that are also available in python we support
#: for multipart messages.
-_supported_multipart_encodings = frozenset(['base64', 'quoted-printable'])
+_supported_multipart_encodings = frozenset(["base64", "quoted-printable"])
-def default_stream_factory(total_content_length, filename, content_type,
- content_length=None):
+def default_stream_factory(
+ total_content_length, filename, content_type, content_length=None
+):
"""The stream factory that is used per default."""
max_size = 1024 * 500
if SpooledTemporaryFile is not None:
- return SpooledTemporaryFile(max_size=max_size, mode='wb+')
+ return SpooledTemporaryFile(max_size=max_size, mode="wb+")
if total_content_length is None or total_content_length > max_size:
- return TemporaryFile('wb+')
+ return TemporaryFile("wb+")
return BytesIO()
-def parse_form_data(environ, stream_factory=None, charset='utf-8',
- errors='replace', max_form_memory_size=None,
- max_content_length=None, cls=None,
- silent=True):
+def parse_form_data(
+ environ,
+ stream_factory=None,
+ charset="utf-8",
+ errors="replace",
+ max_form_memory_size=None,
+ max_content_length=None,
+ cls=None,
+ silent=True,
+):
"""Parse the form data in the environ and return it as tuple in the form
``(stream, form, files)``. You should only call this method if the
transport method is `POST`, `PUT`, or `PATCH`.
@@ -97,9 +111,15 @@ def parse_form_data(environ, stream_factory=None, charset='utf-8',
:param silent: If set to False parsing errors will not be caught.
:return: A tuple in the form ``(stream, form, files)``.
"""
- return FormDataParser(stream_factory, charset, errors,
- max_form_memory_size, max_content_length,
- cls, silent).parse_from_environ(environ)
+ return FormDataParser(
+ stream_factory,
+ charset,
+ errors,
+ max_form_memory_size,
+ max_content_length,
+ cls,
+ silent,
+ ).parse_from_environ(environ)
def exhaust_stream(f):
@@ -109,7 +129,7 @@ def exhaust_stream(f):
try:
return f(self, stream, *args, **kwargs)
finally:
- exhaust = getattr(stream, 'exhaust', None)
+ exhaust = getattr(stream, "exhaust", None)
if exhaust is not None:
exhaust()
else:
@@ -117,11 +137,11 @@ def exhaust_stream(f):
chunk = stream.read(1024 * 64)
if not chunk:
break
+
return update_wrapper(wrapper, f)
class FormDataParser(object):
-
"""This class implements parsing of form data for Werkzeug. By itself
it can parse multipart and url encoded form data. It can be subclassed
and extended but for most mimetypes it is a better idea to use the
@@ -149,10 +169,16 @@ class FormDataParser(object):
:param silent: If set to False parsing errors will not be caught.
"""
- def __init__(self, stream_factory=None, charset='utf-8',
- errors='replace', max_form_memory_size=None,
- max_content_length=None, cls=None,
- silent=True):
+ def __init__(
+ self,
+ stream_factory=None,
+ charset="utf-8",
+ errors="replace",
+ max_form_memory_size=None,
+ max_content_length=None,
+ cls=None,
+ silent=True,
+ ):
if stream_factory is None:
stream_factory = default_stream_factory
self.stream_factory = stream_factory
@@ -174,11 +200,10 @@ class FormDataParser(object):
:param environ: the WSGI environment to be used for parsing.
:return: A tuple in the form ``(stream, form, files)``.
"""
- content_type = environ.get('CONTENT_TYPE', '')
+ content_type = environ.get("CONTENT_TYPE", "")
content_length = get_content_length(environ)
mimetype, options = parse_options_header(content_type)
- return self.parse(get_input_stream(environ), mimetype,
- content_length, options)
+ return self.parse(get_input_stream(environ), mimetype, content_length, options)
def parse(self, stream, mimetype, content_length, options=None):
"""Parses the information from the given stream, mimetype,
@@ -191,9 +216,11 @@ class FormDataParser(object):
the multipart boundary for instance)
:return: A tuple in the form ``(stream, form, files)``.
"""
- if self.max_content_length is not None and \
- content_length is not None and \
- content_length > self.max_content_length:
+ if (
+ self.max_content_length is not None
+ and content_length is not None
+ and content_length > self.max_content_length
+ ):
raise exceptions.RequestEntityTooLarge()
if options is None:
options = {}
@@ -201,8 +228,7 @@ class FormDataParser(object):
parse_func = self.get_parse_func(mimetype, options)
if parse_func is not None:
try:
- return parse_func(self, stream, mimetype,
- content_length, options)
+ return parse_func(self, stream, mimetype, content_length, options)
except ValueError:
if not self.silent:
raise
@@ -211,32 +237,37 @@ class FormDataParser(object):
@exhaust_stream
def _parse_multipart(self, stream, mimetype, content_length, options):
- parser = MultiPartParser(self.stream_factory, self.charset, self.errors,
- max_form_memory_size=self.max_form_memory_size,
- cls=self.cls)
- boundary = options.get('boundary')
+ parser = MultiPartParser(
+ self.stream_factory,
+ self.charset,
+ self.errors,
+ max_form_memory_size=self.max_form_memory_size,
+ cls=self.cls,
+ )
+ boundary = options.get("boundary")
if boundary is None:
- raise ValueError('Missing boundary')
+ raise ValueError("Missing boundary")
if isinstance(boundary, text_type):
- boundary = boundary.encode('ascii')
+ boundary = boundary.encode("ascii")
form, files = parser.parse(stream, boundary, content_length)
return stream, form, files
@exhaust_stream
def _parse_urlencoded(self, stream, mimetype, content_length, options):
- if self.max_form_memory_size is not None and \
- content_length is not None and \
- content_length > self.max_form_memory_size:
+ if (
+ self.max_form_memory_size is not None
+ and content_length is not None
+ and content_length > self.max_form_memory_size
+ ):
raise exceptions.RequestEntityTooLarge()
- form = url_decode_stream(stream, self.charset,
- errors=self.errors, cls=self.cls)
+ form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls)
return stream, form, self.cls()
#: mapping of mimetypes to parsing functions
parse_functions = {
- 'multipart/form-data': _parse_multipart,
- 'application/x-www-form-urlencoded': _parse_urlencoded,
- 'application/x-url-encoded': _parse_urlencoded
+ "multipart/form-data": _parse_multipart,
+ "application/x-www-form-urlencoded": _parse_urlencoded,
+ "application/x-url-encoded": _parse_urlencoded,
}
@@ -249,9 +280,9 @@ def _line_parse(line):
"""Removes line ending characters and returns a tuple (`stripped_line`,
`is_terminated`).
"""
- if line[-2:] in ['\r\n', b'\r\n']:
+ if line[-2:] in ["\r\n", b"\r\n"]:
return line[:-2], True
- elif line[-1:] in ['\r', '\n', b'\r', b'\n']:
+ elif line[-1:] in ["\r", "\n", b"\r", b"\n"]:
return line[:-1], True
return line, False
@@ -270,14 +301,14 @@ def parse_multipart_headers(iterable):
line = to_native(line)
line, line_terminated = _line_parse(line)
if not line_terminated:
- raise ValueError('unexpected end of line in multipart header')
+ raise ValueError("unexpected end of line in multipart header")
if not line:
break
- elif line[0] in ' \t' and result:
+ elif line[0] in " \t" and result:
key, value = result[-1]
- result[-1] = (key, value + '\n ' + line[1:])
+ result[-1] = (key, value + "\n " + line[1:])
else:
- parts = line.split(':', 1)
+ parts = line.split(":", 1)
if len(parts) == 2:
result.append((parts[0].strip(), parts[1].strip()))
@@ -286,28 +317,36 @@ def parse_multipart_headers(iterable):
return Headers(result)
-_begin_form = 'begin_form'
-_begin_file = 'begin_file'
-_cont = 'cont'
-_end = 'end'
+_begin_form = "begin_form"
+_begin_file = "begin_file"
+_cont = "cont"
+_end = "end"
class MultiPartParser(object):
-
- def __init__(self, stream_factory=None, charset='utf-8', errors='replace',
- max_form_memory_size=None, cls=None, buffer_size=64 * 1024):
+ def __init__(
+ self,
+ stream_factory=None,
+ charset="utf-8",
+ errors="replace",
+ max_form_memory_size=None,
+ cls=None,
+ buffer_size=64 * 1024,
+ ):
self.charset = charset
self.errors = errors
self.max_form_memory_size = max_form_memory_size
- self.stream_factory = default_stream_factory if stream_factory is None else stream_factory
+ self.stream_factory = (
+ default_stream_factory if stream_factory is None else stream_factory
+ )
self.cls = MultiDict if cls is None else cls
# make sure the buffer size is divisible by four so that we can base64
# decode chunk by chunk
- assert buffer_size % 4 == 0, 'buffer size has to be divisible by 4'
+ assert buffer_size % 4 == 0, "buffer size has to be divisible by 4"
# also the buffer size has to be at least 1024 bytes long or long headers
# will freak out the system
- assert buffer_size >= 1024, 'buffer size has to be at least 1KB'
+ assert buffer_size >= 1024, "buffer size has to be at least 1KB"
self.buffer_size = buffer_size
@@ -316,8 +355,8 @@ class MultiPartParser(object):
uploaded. This function strips the full path if it thinks the
filename is Windows-like absolute.
"""
- if filename[1:3] == ':\\' or filename[:2] == '\\\\':
- return filename.split('\\')[-1]
+ if filename[1:3] == ":\\" or filename[:2] == "\\\\":
+ return filename.split("\\")[-1]
return filename
def _find_terminator(self, iterator):
@@ -331,36 +370,39 @@ class MultiPartParser(object):
line = line.strip()
if line:
return line
- return b''
+ return b""
def fail(self, message):
raise ValueError(message)
def get_part_encoding(self, headers):
- transfer_encoding = headers.get('content-transfer-encoding')
- if transfer_encoding is not None and \
- transfer_encoding in _supported_multipart_encodings:
+ transfer_encoding = headers.get("content-transfer-encoding")
+ if (
+ transfer_encoding is not None
+ and transfer_encoding in _supported_multipart_encodings
+ ):
return transfer_encoding
def get_part_charset(self, headers):
# Figure out input charset for current part
- content_type = headers.get('content-type')
+ content_type = headers.get("content-type")
if content_type:
mimetype, ct_params = parse_options_header(content_type)
- return ct_params.get('charset', self.charset)
+ return ct_params.get("charset", self.charset)
return self.charset
def start_file_streaming(self, filename, headers, total_content_length):
if isinstance(filename, bytes):
filename = filename.decode(self.charset, self.errors)
filename = self._fix_ie_filename(filename)
- content_type = headers.get('content-type')
+ content_type = headers.get("content-type")
try:
- content_length = int(headers['content-length'])
+ content_length = int(headers["content-length"])
except (KeyError, ValueError):
content_length = 0
- container = self.stream_factory(total_content_length, content_type,
- filename, content_length)
+ container = self.stream_factory(
+ total_content_length, content_type, filename, content_length
+ )
return filename, container
def in_memory_threshold_reached(self, bytes):
@@ -368,15 +410,15 @@ class MultiPartParser(object):
def validate_boundary(self, boundary):
if not boundary:
- self.fail('Missing boundary')
+ self.fail("Missing boundary")
if not is_valid_multipart_boundary(boundary):
- self.fail('Invalid boundary: %s' % boundary)
+ self.fail("Invalid boundary: %s" % boundary)
if len(boundary) > self.buffer_size: # pragma: no cover
# this should never happen because we check for a minimum size
# of 1024 and boundaries may not be longer than 200. The only
# situation when this happens is for non debug builds where
# the assert is skipped.
- self.fail('Boundary longer than buffer size')
+ self.fail("Boundary longer than buffer size")
def parse_lines(self, file, boundary, content_length, cap_at_buffer=True):
"""Generate parts of
@@ -389,31 +431,36 @@ class MultiPartParser(object):
parts = ( begin_form cont* end |
begin_file cont* end )*
"""
- next_part = b'--' + boundary
- last_part = next_part + b'--'
-
- iterator = chain(make_line_iter(file, limit=content_length,
- buffer_size=self.buffer_size,
- cap_at_buffer=cap_at_buffer),
- _empty_string_iter)
+ next_part = b"--" + boundary
+ last_part = next_part + b"--"
+
+ iterator = chain(
+ make_line_iter(
+ file,
+ limit=content_length,
+ buffer_size=self.buffer_size,
+ cap_at_buffer=cap_at_buffer,
+ ),
+ _empty_string_iter,
+ )
terminator = self._find_terminator(iterator)
if terminator == last_part:
return
elif terminator != next_part:
- self.fail('Expected boundary at start of multipart data')
+ self.fail("Expected boundary at start of multipart data")
while terminator != last_part:
headers = parse_multipart_headers(iterator)
- disposition = headers.get('content-disposition')
+ disposition = headers.get("content-disposition")
if disposition is None:
- self.fail('Missing Content-Disposition header')
+ self.fail("Missing Content-Disposition header")
disposition, extra = parse_options_header(disposition)
transfer_encoding = self.get_part_encoding(headers)
- name = extra.get('name')
- filename = extra.get('filename')
+ name = extra.get("name")
+ filename = extra.get("filename")
# if no content type is given we stream into memory. A list is
# used as a temporary container.
@@ -425,29 +472,29 @@ class MultiPartParser(object):
else:
yield _begin_file, (headers, name, filename)
- buf = b''
+ buf = b""
for line in iterator:
if not line:
- self.fail('unexpected end of stream')
+ self.fail("unexpected end of stream")
- if line[:2] == b'--':
+ if line[:2] == b"--":
terminator = line.rstrip()
if terminator in (next_part, last_part):
break
if transfer_encoding is not None:
- if transfer_encoding == 'base64':
- transfer_encoding = 'base64_codec'
+ if transfer_encoding == "base64":
+ transfer_encoding = "base64_codec"
try:
line = codecs.decode(line, transfer_encoding)
except Exception:
- self.fail('could not decode transfer encoded chunk')
+ self.fail("could not decode transfer encoded chunk")
# we have something in the buffer from the last iteration.
# this is usually a newline delimiter.
if buf:
yield _cont, buf
- buf = b''
+ buf = b""
# If the line ends with windows CRLF we write everything except
# the last two bytes. In all other cases however we write
@@ -458,8 +505,8 @@ class MultiPartParser(object):
# truncate the stream. However we do have to make sure that
# if something else than a newline is in there we write it
# out.
- if line[-2:] == b'\r\n':
- buf = b'\r\n'
+ if line[-2:] == b"\r\n":
+ buf = b"\r\n"
cutoff = -2
else:
buf = line[-1:]
@@ -467,12 +514,12 @@ class MultiPartParser(object):
yield _cont, line[:cutoff]
else: # pragma: no cover
- raise ValueError('unexpected end of part')
+ raise ValueError("unexpected end of part")
# if we have a leftover in the buffer that is not a newline
# character we have to flush it, otherwise we will chop of
# certain values.
- if buf not in (b'', b'\r', b'\n', b'\r\n'):
+ if buf not in (b"", b"\r", b"\n", b"\r\n"):
yield _cont, buf
yield _end, None
@@ -489,7 +536,8 @@ class MultiPartParser(object):
is_file = True
guard_memory = False
filename, container = self.start_file_streaming(
- filename, headers, content_length)
+ filename, headers, content_length
+ )
_write = container.write
elif ellt == _begin_form:
@@ -512,21 +560,24 @@ class MultiPartParser(object):
elif ellt == _end:
if is_file:
container.seek(0)
- yield ('file',
- (name, FileStorage(container, filename, name,
- headers=headers)))
+ yield (
+ "file",
+ (name, FileStorage(container, filename, name, headers=headers)),
+ )
else:
part_charset = self.get_part_charset(headers)
- yield ('form',
- (name, b''.join(container).decode(
- part_charset, self.errors)))
+ yield (
+ "form",
+ (name, b"".join(container).decode(part_charset, self.errors)),
+ )
def parse(self, file, boundary, content_length):
formstream, filestream = tee(
- self.parse_parts(file, boundary, content_length), 2)
- form = (p[1] for p in formstream if p[0] == 'form')
- files = (p[1] for p in filestream if p[0] == 'file')
+ self.parse_parts(file, boundary, content_length), 2
+ )
+ form = (p[1] for p in formstream if p[0] == "form")
+ files = (p[1] for p in filestream if p[0] == "file")
return self.cls(form), self.cls(files)
-from werkzeug import exceptions
+from . import exceptions
diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py
index 9e18ea97..af320075 100644
--- a/src/werkzeug/http.py
+++ b/src/werkzeug/http.py
@@ -16,54 +16,68 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+import base64
import re
import warnings
-from time import time, gmtime
+from datetime import datetime
+from datetime import timedelta
+from hashlib import md5
+from time import gmtime
+from time import time
+
+from ._compat import integer_types
+from ._compat import iteritems
+from ._compat import PY2
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import to_unicode
+from ._compat import try_coerce_native
+from ._internal import _cookie_parse_impl
+from ._internal import _cookie_quote
+from ._internal import _make_cookie_domain
+
try:
from email.utils import parsedate_tz
-except ImportError: # pragma: no cover
+except ImportError:
from email.Utils import parsedate_tz
+
try:
from urllib.request import parse_http_list as _parse_list_header
from urllib.parse import unquote_to_bytes as _unquote
-except ImportError: # pragma: no cover
- from urllib2 import parse_http_list as _parse_list_header, \
- unquote as _unquote
-from datetime import datetime, timedelta
-from hashlib import md5
-import base64
+except ImportError:
+ from urllib2 import parse_http_list as _parse_list_header
+ from urllib2 import unquote as _unquote
-from werkzeug._internal import _cookie_quote, _make_cookie_domain, \
- _cookie_parse_impl
-from werkzeug._compat import to_unicode, iteritems, text_type, \
- string_types, try_coerce_native, to_bytes, PY2, \
- integer_types
-
-
-_cookie_charset = 'latin1'
-_basic_auth_charset = 'utf-8'
+_cookie_charset = "latin1"
+_basic_auth_charset = "utf-8"
# for explanation of "media-range", etc. see Sections 5.3.{1,2} of RFC 7231
_accept_re = re.compile(
- r'''( # media-range capturing-parenthesis
- [^\s;,]+ # type/subtype
- (?:[ \t]*;[ \t]* # ";"
- (?: # parameter non-capturing-parenthesis
- [^\s;,q][^\s;,]* # token that doesn't start with "q"
- | # or
- q[^\s;,=][^\s;,]* # token that is more than just "q"
- )
- )* # zero or more parameters
- ) # end of media-range
- (?:[ \t]*;[ \t]*q= # weight is a "q" parameter
- (\d*(?:\.\d+)?) # qvalue capturing-parentheses
- [^,]* # "extension" accept params: who cares?
- )? # accept params are optional
- ''', re.VERBOSE)
-_token_chars = frozenset("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- '^_`abcdefghijklmnopqrstuvwxyz|~')
+ r"""
+ ( # media-range capturing-parenthesis
+ [^\s;,]+ # type/subtype
+ (?:[ \t]*;[ \t]* # ";"
+ (?: # parameter non-capturing-parenthesis
+ [^\s;,q][^\s;,]* # token that doesn't start with "q"
+ | # or
+ q[^\s;,=][^\s;,]* # token that is more than just "q"
+ )
+ )* # zero or more parameters
+ ) # end of media-range
+ (?:[ \t]*;[ \t]*q= # weight is a "q" parameter
+ (\d*(?:\.\d+)?) # qvalue capturing-parentheses
+ [^,]* # "extension" accept params: who cares?
+ )? # accept params are optional
+ """,
+ re.VERBOSE,
+)
+_token_chars = frozenset(
+ "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"
+)
_etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)')
-_unsafe_header_chars = set('()<>@,;:\"/[]?={} \t')
-_option_header_piece_re = re.compile(r'''
+_unsafe_header_chars = set('()<>@,;:"/[]?={} \t')
+_option_header_piece_re = re.compile(
+ r"""
;\s*,?\s* # newlines were replaced with commas
(?P<key>
"[^"\\]*(?:\\.[^"\\]*)*" # quoted string
@@ -89,100 +103,116 @@ _option_header_piece_re = re.compile(r'''
)?
)?
\s*
-''', flags=re.VERBOSE)
-_option_header_start_mime_type = re.compile(r',\s*([^;,\s]+)([;,]\s*.+)?')
-
-_entity_headers = frozenset([
- 'allow', 'content-encoding', 'content-language', 'content-length',
- 'content-location', 'content-md5', 'content-range', 'content-type',
- 'expires', 'last-modified'
-])
-_hop_by_hop_headers = frozenset([
- 'connection', 'keep-alive', 'proxy-authenticate',
- 'proxy-authorization', 'te', 'trailer', 'transfer-encoding',
- 'upgrade'
-])
+ """,
+ flags=re.VERBOSE,
+)
+_option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?")
+
+_entity_headers = frozenset(
+ [
+ "allow",
+ "content-encoding",
+ "content-language",
+ "content-length",
+ "content-location",
+ "content-md5",
+ "content-range",
+ "content-type",
+ "expires",
+ "last-modified",
+ ]
+)
+_hop_by_hop_headers = frozenset(
+ [
+ "connection",
+ "keep-alive",
+ "proxy-authenticate",
+ "proxy-authorization",
+ "te",
+ "trailer",
+ "transfer-encoding",
+ "upgrade",
+ ]
+)
HTTP_STATUS_CODES = {
- 100: 'Continue',
- 101: 'Switching Protocols',
- 102: 'Processing',
- 200: 'OK',
- 201: 'Created',
- 202: 'Accepted',
- 203: 'Non Authoritative Information',
- 204: 'No Content',
- 205: 'Reset Content',
- 206: 'Partial Content',
- 207: 'Multi Status',
- 226: 'IM Used', # see RFC 3229
- 300: 'Multiple Choices',
- 301: 'Moved Permanently',
- 302: 'Found',
- 303: 'See Other',
- 304: 'Not Modified',
- 305: 'Use Proxy',
- 307: 'Temporary Redirect',
- 308: 'Permanent Redirect',
- 400: 'Bad Request',
- 401: 'Unauthorized',
- 402: 'Payment Required', # unused
- 403: 'Forbidden',
- 404: 'Not Found',
- 405: 'Method Not Allowed',
- 406: 'Not Acceptable',
- 407: 'Proxy Authentication Required',
- 408: 'Request Timeout',
- 409: 'Conflict',
- 410: 'Gone',
- 411: 'Length Required',
- 412: 'Precondition Failed',
- 413: 'Request Entity Too Large',
- 414: 'Request URI Too Long',
- 415: 'Unsupported Media Type',
- 416: 'Requested Range Not Satisfiable',
- 417: 'Expectation Failed',
- 418: 'I\'m a teapot', # see RFC 2324
- 421: 'Misdirected Request', # see RFC 7540
- 422: 'Unprocessable Entity',
- 423: 'Locked',
- 424: 'Failed Dependency',
- 426: 'Upgrade Required',
- 428: 'Precondition Required', # see RFC 6585
- 429: 'Too Many Requests',
- 431: 'Request Header Fields Too Large',
- 449: 'Retry With', # proprietary MS extension
- 451: 'Unavailable For Legal Reasons',
- 500: 'Internal Server Error',
- 501: 'Not Implemented',
- 502: 'Bad Gateway',
- 503: 'Service Unavailable',
- 504: 'Gateway Timeout',
- 505: 'HTTP Version Not Supported',
- 507: 'Insufficient Storage',
- 510: 'Not Extended'
+ 100: "Continue",
+ 101: "Switching Protocols",
+ 102: "Processing",
+ 200: "OK",
+ 201: "Created",
+ 202: "Accepted",
+ 203: "Non Authoritative Information",
+ 204: "No Content",
+ 205: "Reset Content",
+ 206: "Partial Content",
+ 207: "Multi Status",
+ 226: "IM Used", # see RFC 3229
+ 300: "Multiple Choices",
+ 301: "Moved Permanently",
+ 302: "Found",
+ 303: "See Other",
+ 304: "Not Modified",
+ 305: "Use Proxy",
+ 307: "Temporary Redirect",
+ 308: "Permanent Redirect",
+ 400: "Bad Request",
+ 401: "Unauthorized",
+ 402: "Payment Required", # unused
+ 403: "Forbidden",
+ 404: "Not Found",
+ 405: "Method Not Allowed",
+ 406: "Not Acceptable",
+ 407: "Proxy Authentication Required",
+ 408: "Request Timeout",
+ 409: "Conflict",
+ 410: "Gone",
+ 411: "Length Required",
+ 412: "Precondition Failed",
+ 413: "Request Entity Too Large",
+ 414: "Request URI Too Long",
+ 415: "Unsupported Media Type",
+ 416: "Requested Range Not Satisfiable",
+ 417: "Expectation Failed",
+ 418: "I'm a teapot", # see RFC 2324
+ 421: "Misdirected Request", # see RFC 7540
+ 422: "Unprocessable Entity",
+ 423: "Locked",
+ 424: "Failed Dependency",
+ 426: "Upgrade Required",
+ 428: "Precondition Required", # see RFC 6585
+ 429: "Too Many Requests",
+ 431: "Request Header Fields Too Large",
+ 449: "Retry With", # proprietary MS extension
+ 451: "Unavailable For Legal Reasons",
+ 500: "Internal Server Error",
+ 501: "Not Implemented",
+ 502: "Bad Gateway",
+ 503: "Service Unavailable",
+ 504: "Gateway Timeout",
+ 505: "HTTP Version Not Supported",
+ 507: "Insufficient Storage",
+ 510: "Not Extended",
}
def wsgi_to_bytes(data):
- """coerce wsgi unicode represented bytes to real ones
-
- """
+ """coerce wsgi unicode represented bytes to real ones"""
if isinstance(data, bytes):
return data
- return data.encode('latin1') # XXX: utf8 fallback?
+ return data.encode("latin1") # XXX: utf8 fallback?
def bytes_to_wsgi(data):
- assert isinstance(data, bytes), 'data must be bytes'
+ assert isinstance(data, bytes), "data must be bytes"
if isinstance(data, str):
return data
else:
- return data.decode('latin1')
+ return data.decode("latin1")
-def quote_header_value(value, extra_chars='', allow_token=True):
+def quote_header_value(value, extra_chars="", allow_token=True):
"""Quote a header value if necessary.
.. versionadded:: 0.5
@@ -199,7 +229,7 @@ def quote_header_value(value, extra_chars='', allow_token=True):
token_chars = _token_chars | set(extra_chars)
if set(value).issubset(token_chars):
return value
- return '"%s"' % value.replace('\\', '\\\\').replace('"', '\\"')
+ return '"%s"' % value.replace("\\", "\\\\").replace('"', '\\"')
def unquote_header_value(value, is_filename=False):
@@ -223,8 +253,8 @@ def unquote_header_value(value, is_filename=False):
# replace sequence below on a UNC path has the effect of turning
# the leading double slash into a single slash and then
# _fix_ie_filename() doesn't work correctly. See #458.
- if not is_filename or value[:2] != '\\\\':
- return value.replace('\\\\', '\\').replace('\\"', '"')
+ if not is_filename or value[:2] != "\\\\":
+ return value.replace("\\\\", "\\").replace('\\"', '"')
return value
@@ -241,8 +271,8 @@ def dump_options_header(header, options):
if value is None:
segments.append(key)
else:
- segments.append('%s=%s' % (key, quote_header_value(value)))
- return '; '.join(segments)
+ segments.append("%s=%s" % (key, quote_header_value(value)))
+ return "; ".join(segments)
def dump_header(iterable, allow_token=True):
@@ -266,14 +296,12 @@ def dump_header(iterable, allow_token=True):
if value is None:
items.append(key)
else:
- items.append('%s=%s' % (
- key,
- quote_header_value(value, allow_token=allow_token)
- ))
+ items.append(
+ "%s=%s" % (key, quote_header_value(value, allow_token=allow_token))
+ )
else:
- items = [quote_header_value(x, allow_token=allow_token)
- for x in iterable]
- return ', '.join(items)
+ items = [quote_header_value(x, allow_token=allow_token) for x in iterable]
+ return ", ".join(items)
def parse_list_header(value):
@@ -337,10 +365,10 @@ def parse_dict_header(value, cls=dict):
# XXX: validate
value = bytes_to_wsgi(value)
for item in _parse_list_header(value):
- if '=' not in item:
+ if "=" not in item:
result[item] = None
continue
- name, value = item.split('=', 1)
+ name, value = item.split("=", 1)
if value[:1] == value[-1:] == '"':
value = unquote_header_value(value[1:-1])
result[name] = value
@@ -369,7 +397,7 @@ def parse_options_header(value, multiple=False):
if multiple=True
"""
if not value:
- return '', {}
+ return "", {}
result = []
@@ -400,9 +428,7 @@ def parse_options_header(value, multiple=False):
continued_encoding = encoding
option = unquote_header_value(option)
if option_value is not None:
- option_value = unquote_header_value(
- option_value,
- option == 'filename')
+ option_value = unquote_header_value(option_value, option == "filename")
if encoding is not None:
option_value = _unquote(option_value).decode(encoding)
if count:
@@ -412,13 +438,13 @@ def parse_options_header(value, multiple=False):
options[option] = options.get(option, "") + option_value
else:
options[option] = option_value
- rest = rest[optmatch.end():]
+ rest = rest[optmatch.end() :]
result.append(options)
if multiple is False:
return tuple(result)
value = rest
- return tuple(result) if result else ('', {})
+ return tuple(result) if result else ("", {})
def parse_accept_header(value, cls=None):
@@ -525,26 +551,27 @@ def parse_authorization_header(value):
auth_type = auth_type.lower()
except ValueError:
return
- if auth_type == b'basic':
+ if auth_type == b"basic":
try:
- username, password = base64.b64decode(auth_info).split(b':', 1)
+ username, password = base64.b64decode(auth_info).split(b":", 1)
except Exception:
return
return Authorization(
- 'basic', {
- 'username': to_unicode(username, _basic_auth_charset),
- 'password': to_unicode(password, _basic_auth_charset)
- }
+ "basic",
+ {
+ "username": to_unicode(username, _basic_auth_charset),
+ "password": to_unicode(password, _basic_auth_charset),
+ },
)
- elif auth_type == b'digest':
+ elif auth_type == b"digest":
auth_map = parse_dict_header(auth_info)
- for key in 'username', 'realm', 'nonce', 'uri', 'response':
+ for key in "username", "realm", "nonce", "uri", "response":
if key not in auth_map:
return
- if 'qop' in auth_map:
- if not auth_map.get('nc') or not auth_map.get('cnonce'):
+ if "qop" in auth_map:
+ if not auth_map.get("nc") or not auth_map.get("cnonce"):
return
- return Authorization('digest', auth_map)
+ return Authorization("digest", auth_map)
def parse_www_authenticate_header(value, on_update=None):
@@ -564,8 +591,7 @@ def parse_www_authenticate_header(value, on_update=None):
auth_type = auth_type.lower()
except (ValueError, AttributeError):
return WWWAuthenticate(value.strip().lower(), on_update=on_update)
- return WWWAuthenticate(auth_type, parse_dict_header(auth_info),
- on_update)
+ return WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update)
def parse_if_range_header(value):
@@ -591,19 +617,19 @@ def parse_range_header(value, make_inclusive=True):
.. versionadded:: 0.7
"""
- if not value or '=' not in value:
+ if not value or "=" not in value:
return None
ranges = []
last_end = 0
- units, rng = value.split('=', 1)
+ units, rng = value.split("=", 1)
units = units.strip().lower()
- for item in rng.split(','):
+ for item in rng.split(","):
item = item.strip()
- if '-' not in item:
+ if "-" not in item:
return None
- if item.startswith('-'):
+ if item.startswith("-"):
if last_end < 0:
return None
try:
@@ -612,8 +638,8 @@ def parse_range_header(value, make_inclusive=True):
return None
end = None
last_end = -1
- elif '-' in item:
- begin, end = item.split('-', 1)
+ elif "-" in item:
+ begin, end = item.split("-", 1)
begin = begin.strip()
end = end.strip()
if not begin.isdigit():
@@ -650,26 +676,26 @@ def parse_content_range_header(value, on_update=None):
if value is None:
return None
try:
- units, rangedef = (value or '').strip().split(None, 1)
+ units, rangedef = (value or "").strip().split(None, 1)
except ValueError:
return None
- if '/' not in rangedef:
+ if "/" not in rangedef:
return None
- rng, length = rangedef.split('/', 1)
- if length == '*':
+ rng, length = rangedef.split("/", 1)
+ if length == "*":
length = None
elif length.isdigit():
length = int(length)
else:
return None
- if rng == '*':
+ if rng == "*":
return ContentRange(units, None, None, length, on_update=on_update)
- elif '-' not in rng:
+ elif "-" not in rng:
return None
- start, stop = rng.split('-', 1)
+ start, stop = rng.split("-", 1)
try:
start = int(start)
stop = int(stop) + 1
@@ -687,10 +713,10 @@ def quote_etag(etag, weak=False):
:param weak: set to `True` to tag it "weak".
"""
if '"' in etag:
- raise ValueError('invalid etag')
+ raise ValueError("invalid etag")
etag = '"%s"' % etag
if weak:
- etag = 'W/' + etag
+ etag = "W/" + etag
return etag
@@ -709,7 +735,7 @@ def unquote_etag(etag):
return None, None
etag = etag.strip()
weak = False
- if etag.startswith(('W/', 'w/')):
+ if etag.startswith(("W/", "w/")):
weak = True
etag = etag[2:]
if etag[:1] == etag[-1:] == '"':
@@ -734,7 +760,7 @@ def parse_etags(value):
if match is None:
break
is_weak, quoted, raw = match.groups()
- if raw == '*':
+ if raw == "*":
return ETags(star_tag=True)
elif quoted:
raw = quoted
@@ -778,8 +804,7 @@ def parse_date(value):
year += 2000
elif year >= 69 and year <= 99:
year += 1900
- return datetime(*((year,) + t[1:7])) - \
- timedelta(seconds=t[-1] or 0)
+ return datetime(*((year,) + t[1:7])) - timedelta(seconds=t[-1] or 0)
except (ValueError, OverflowError):
return None
@@ -792,12 +817,29 @@ def _dump_date(d, delim):
d = d.utctimetuple()
elif isinstance(d, (integer_types, float)):
d = gmtime(d)
- return '%s, %02d%s%s%s%s %02d:%02d:%02d GMT' % (
- ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')[d.tm_wday],
- d.tm_mday, delim,
- ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep',
- 'Oct', 'Nov', 'Dec')[d.tm_mon - 1],
- delim, str(d.tm_year), d.tm_hour, d.tm_min, d.tm_sec
+ return "%s, %02d%s%s%s%s %02d:%02d:%02d GMT" % (
+ ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[d.tm_wday],
+ d.tm_mday,
+ delim,
+ (
+ "Jan",
+ "Feb",
+ "Mar",
+ "Apr",
+ "May",
+ "Jun",
+ "Jul",
+ "Aug",
+ "Sep",
+ "Oct",
+ "Nov",
+ "Dec",
+ )[d.tm_mon - 1],
+ delim,
+ str(d.tm_year),
+ d.tm_hour,
+ d.tm_min,
+ d.tm_sec,
)
@@ -813,7 +855,7 @@ def cookie_date(expires=None):
:param expires: If provided that date is used, otherwise the current.
"""
- return _dump_date(expires, '-')
+ return _dump_date(expires, "-")
def http_date(timestamp=None):
@@ -827,7 +869,7 @@ def http_date(timestamp=None):
:param timestamp: If provided that date is used, otherwise the current.
"""
- return _dump_date(timestamp, ' ')
+ return _dump_date(timestamp, " ")
def parse_age(value=None):
@@ -868,13 +910,14 @@ def dump_age(age=None):
age = int(age)
if age < 0:
- raise ValueError('age cannot be negative')
+ raise ValueError("age cannot be negative")
return str(age)
-def is_resource_modified(environ, etag=None, data=None, last_modified=None,
- ignore_if_range=True):
+def is_resource_modified(
+ environ, etag=None, data=None, last_modified=None, ignore_if_range=True
+):
"""Convenience method for conditional requests.
:param environ: the WSGI environment of the request to be checked.
@@ -889,8 +932,8 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None,
if etag is None and data is not None:
etag = generate_etag(data)
elif data is not None:
- raise TypeError('both data and etag given')
- if environ['REQUEST_METHOD'] not in ('GET', 'HEAD'):
+ raise TypeError("both data and etag given")
+ if environ["REQUEST_METHOD"] not in ("GET", "HEAD"):
return False
unmodified = False
@@ -903,16 +946,16 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None,
last_modified = last_modified.replace(microsecond=0)
if_range = None
- if not ignore_if_range and 'HTTP_RANGE' in environ:
+ if not ignore_if_range and "HTTP_RANGE" in environ:
# https://tools.ietf.org/html/rfc7233#section-3.2
# A server MUST ignore an If-Range header field received in a request
# that does not contain a Range header field.
- if_range = parse_if_range_header(environ.get('HTTP_IF_RANGE'))
+ if_range = parse_if_range_header(environ.get("HTTP_IF_RANGE"))
if if_range is not None and if_range.date is not None:
modified_since = if_range.date
else:
- modified_since = parse_date(environ.get('HTTP_IF_MODIFIED_SINCE'))
+ modified_since = parse_date(environ.get("HTTP_IF_MODIFIED_SINCE"))
if modified_since and last_modified and last_modified <= modified_since:
unmodified = True
@@ -922,7 +965,7 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None,
if if_range is not None and if_range.etag is not None:
unmodified = parse_etags(if_range.etag).contains(etag)
else:
- if_none_match = parse_etags(environ.get('HTTP_IF_NONE_MATCH'))
+ if_none_match = parse_etags(environ.get("HTTP_IF_NONE_MATCH"))
if if_none_match:
# https://tools.ietf.org/html/rfc7232#section-3.2
# "A recipient MUST use the weak comparison function when comparing
@@ -932,14 +975,14 @@ def is_resource_modified(environ, etag=None, data=None, last_modified=None,
# https://tools.ietf.org/html/rfc7232#section-3.1
# "Origin server MUST use the strong comparison function when
# comparing entity-tags for If-Match"
- if_match = parse_etags(environ.get('HTTP_IF_MATCH'))
+ if_match = parse_etags(environ.get("HTTP_IF_MATCH"))
if if_match:
unmodified = not if_match.is_strong(etag)
return not unmodified
-def remove_entity_headers(headers, allowed=('expires', 'content-location')):
+def remove_entity_headers(headers, allowed=("expires", "content-location")):
"""Remove all entity headers from a list or :class:`Headers` object. This
operation works in-place. `Expires` and `Content-Location` headers are
by default not removed. The reason for this is :rfc:`2616` section
@@ -953,8 +996,11 @@ def remove_entity_headers(headers, allowed=('expires', 'content-location')):
they are entity headers.
"""
allowed = set(x.lower() for x in allowed)
- headers[:] = [(key, value) for key, value in headers if
- not is_entity_header(key) or key.lower() in allowed]
+ headers[:] = [
+ (key, value)
+ for key, value in headers
+ if not is_entity_header(key) or key.lower() in allowed
+ ]
def remove_hop_by_hop_headers(headers):
@@ -965,8 +1011,9 @@ def remove_hop_by_hop_headers(headers):
:param headers: a list or :class:`Headers` object.
"""
- headers[:] = [(key, value) for key, value in headers if
- not is_hop_by_hop_header(key)]
+ headers[:] = [
+ (key, value) for key, value in headers if not is_hop_by_hop_header(key)
+ ]
def is_entity_header(header):
@@ -991,7 +1038,7 @@ def is_hop_by_hop_header(header):
return header.lower() in _hop_by_hop_headers
-def parse_cookie(header, charset='utf-8', errors='replace', cls=None):
+def parse_cookie(header, charset="utf-8", errors="replace", cls=None):
"""Parse a cookie. Either from a string or WSGI environ.
Per default encoding errors are ignored. If you want a different behavior
@@ -1011,16 +1058,16 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None):
used.
"""
if isinstance(header, dict):
- header = header.get('HTTP_COOKIE', '')
+ header = header.get("HTTP_COOKIE", "")
elif header is None:
- header = ''
+ header = ""
# If the value is an unicode string it's mangled through latin1. This
# is done because on PEP 3333 on Python 3 all headers are assumed latin1
# which however is incorrect for cookies, which are sent in page encoding.
# As a result we
if isinstance(header, text_type):
- header = header.encode('latin1', 'replace')
+ header = header.encode("latin1", "replace")
if cls is None:
cls = TypeConversionDict
@@ -1036,10 +1083,20 @@ def parse_cookie(header, charset='utf-8', errors='replace', cls=None):
return cls(_parse_pairs())
-def dump_cookie(key, value='', max_age=None, expires=None, path='/',
- domain=None, secure=False, httponly=False,
- charset='utf-8', sync_expires=True, max_size=4093,
- samesite=None):
+def dump_cookie(
+ key,
+ value="",
+ max_age=None,
+ expires=None,
+ path="/",
+ domain=None,
+ secure=False,
+ httponly=False,
+ charset="utf-8",
+ sync_expires=True,
+ max_size=4093,
+ samesite=None,
+):
"""Creates a new Set-Cookie header without the ``Set-Cookie`` prefix
The parameters are the same as in the cookie Morsel object in the
Python standard library but it accepts unicode data, too.
@@ -1098,21 +1155,23 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/',
expires = to_bytes(cookie_date(time() + max_age))
samesite = samesite.title() if samesite else None
- if samesite not in ('Strict', 'Lax', None):
+ if samesite not in ("Strict", "Lax", None):
raise ValueError("invalid SameSite value; must be 'Strict', 'Lax' or None")
- buf = [key + b'=' + _cookie_quote(value)]
+ buf = [key + b"=" + _cookie_quote(value)]
# XXX: In theory all of these parameters that are not marked with `None`
# should be quoted. Because stdlib did not quote it before I did not
# want to introduce quoting there now.
- for k, v, q in ((b'Domain', domain, True),
- (b'Expires', expires, False,),
- (b'Max-Age', max_age, False),
- (b'Secure', secure, None),
- (b'HttpOnly', httponly, None),
- (b'Path', path, False),
- (b'SameSite', samesite, False)):
+ for k, v, q in (
+ (b"Domain", domain, True),
+ (b"Expires", expires, False),
+ (b"Max-Age", max_age, False),
+ (b"Secure", secure, None),
+ (b"HttpOnly", httponly, None),
+ (b"Path", path, False),
+ (b"SameSite", samesite, False),
+ ):
if q is None:
if v:
buf.append(k)
@@ -1126,15 +1185,15 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/',
v = to_bytes(text_type(v), charset)
if q:
v = _cookie_quote(v)
- tmp += b'=' + v
+ tmp += b"=" + v
buf.append(bytes(tmp))
# The return value will be an incorrectly encoded latin1 header on
# Python 3 for consistency with the headers object and a bytestring
# on Python 2 because that's how the API makes more sense.
- rv = b'; '.join(buf)
+ rv = b"; ".join(buf)
if not PY2:
- rv = rv.decode('latin1')
+ rv = rv.decode("latin1")
# Warn if the final value of the cookie is less than the limit. If the
# cookie is too large, then it may be silently ignored, which can be quite
@@ -1145,16 +1204,16 @@ def dump_cookie(key, value='', max_age=None, expires=None, path='/',
value_size = len(value)
warnings.warn(
'The "{key}" cookie is too large: the value was {value_size} bytes'
- ' but the header required {extra_size} extra bytes. The final size'
- ' was {cookie_size} bytes but the limit is {max_size} bytes.'
- ' Browsers may silently ignore cookies larger than this.'.format(
+ " but the header required {extra_size} extra bytes. The final size"
+ " was {cookie_size} bytes but the limit is {max_size} bytes."
+ " Browsers may silently ignore cookies larger than this.".format(
key=key,
value_size=value_size,
extra_size=cookie_size - value_size,
cookie_size=cookie_size,
- max_size=max_size
+ max_size=max_size,
),
- stacklevel=2
+ stacklevel=2,
)
return rv
@@ -1177,19 +1236,23 @@ def is_byte_range_valid(start, stop, length):
# circular dependency fun
-from werkzeug.datastructures import Accept, HeaderSet, ETags, Authorization, \
- WWWAuthenticate, TypeConversionDict, IfRange, Range, ContentRange, \
- RequestCacheControl
-from werkzeug.urls import iri_to_uri
-
+from .datastructures import Accept
+from .datastructures import Authorization
+from .datastructures import ContentRange
+from .datastructures import ETags
+from .datastructures import HeaderSet
+from .datastructures import IfRange
+from .datastructures import Range
+from .datastructures import RequestCacheControl
+from .datastructures import TypeConversionDict
+from .datastructures import WWWAuthenticate
+from .urls import iri_to_uri
# DEPRECATED
-from werkzeug.datastructures import (
- MIMEAccept as _MIMEAccept,
- CharsetAccept as _CharsetAccept,
- LanguageAccept as _LanguageAccept,
- Headers as _Headers,
-)
+from .datastructures import CharsetAccept as _CharsetAccept
+from .datastructures import Headers as _Headers
+from .datastructures import LanguageAccept as _LanguageAccept
+from .datastructures import MIMEAccept as _MIMEAccept
class MIMEAccept(_MIMEAccept):
diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py
index af98c305..9a6088cc 100644
--- a/src/werkzeug/local.py
+++ b/src/werkzeug/local.py
@@ -10,8 +10,10 @@
"""
import copy
from functools import update_wrapper
-from werkzeug.wsgi import ClosingIterator
-from werkzeug._compat import PY2, implements_bool
+
+from ._compat import implements_bool
+from ._compat import PY2
+from .wsgi import ClosingIterator
# since each thread has its own greenlet we can just use those as identifiers
# for the context. If greenlets are not available we fall back to the
@@ -49,11 +51,11 @@ def release_local(local):
class Local(object):
- __slots__ = ('__storage__', '__ident_func__')
+ __slots__ = ("__storage__", "__ident_func__")
def __init__(self):
- object.__setattr__(self, '__storage__', {})
- object.__setattr__(self, '__ident_func__', get_ident)
+ object.__setattr__(self, "__storage__", {})
+ object.__setattr__(self, "__ident_func__", get_ident)
def __iter__(self):
return iter(self.__storage__.items())
@@ -87,7 +89,6 @@ class Local(object):
class LocalStack(object):
-
"""This class works similar to a :class:`Local` but keeps a stack
of objects instead. This is best explained with an example::
@@ -124,7 +125,8 @@ class LocalStack(object):
return self._local.__ident_func__
def _set__ident_func__(self, value):
- object.__setattr__(self._local, '__ident_func__', value)
+ object.__setattr__(self._local, "__ident_func__", value)
+
__ident_func__ = property(_get__ident_func__, _set__ident_func__)
del _get__ident_func__, _set__ident_func__
@@ -132,13 +134,14 @@ class LocalStack(object):
def _lookup():
rv = self.top
if rv is None:
- raise RuntimeError('object unbound')
+ raise RuntimeError("object unbound")
return rv
+
return LocalProxy(_lookup)
def push(self, obj):
"""Pushes a new item to the stack"""
- rv = getattr(self._local, 'stack', None)
+ rv = getattr(self._local, "stack", None)
if rv is None:
self._local.stack = rv = []
rv.append(obj)
@@ -148,7 +151,7 @@ class LocalStack(object):
"""Removes the topmost item from the stack, will return the
old value or `None` if the stack was already empty.
"""
- stack = getattr(self._local, 'stack', None)
+ stack = getattr(self._local, "stack", None)
if stack is None:
return None
elif len(stack) == 1:
@@ -169,7 +172,6 @@ class LocalStack(object):
class LocalManager(object):
-
"""Local objects cannot manage themselves. For that you need a local
manager. You can pass a local manager multiple locals or add them later
by appending them to `manager.locals`. Every time the manager cleans up,
@@ -196,7 +198,7 @@ class LocalManager(object):
if ident_func is not None:
self.ident_func = ident_func
for local in self.locals:
- object.__setattr__(local, '__ident_func__', ident_func)
+ object.__setattr__(local, "__ident_func__", ident_func)
else:
self.ident_func = get_ident
@@ -224,8 +226,10 @@ class LocalManager(object):
"""Wrap a WSGI application so that cleaning up happens after
request end.
"""
+
def application(environ, start_response):
return ClosingIterator(app(environ, start_response), self.cleanup)
+
return application
def middleware(self, func):
@@ -244,15 +248,11 @@ class LocalManager(object):
return update_wrapper(self.make_middleware(func), func)
def __repr__(self):
- return '<%s storages: %d>' % (
- self.__class__.__name__,
- len(self.locals)
- )
+ return "<%s storages: %d>" % (self.__class__.__name__, len(self.locals))
@implements_bool
class LocalProxy(object):
-
"""Acts as a proxy for a werkzeug local. Forwards all operations to
a proxied object. The only operations not supported for forwarding
are right handed operands and any kind of assignment.
@@ -287,40 +287,41 @@ class LocalProxy(object):
.. versionchanged:: 0.6.1
The class can be instantiated with a callable as well now.
"""
- __slots__ = ('__local', '__dict__', '__name__', '__wrapped__')
+
+ __slots__ = ("__local", "__dict__", "__name__", "__wrapped__")
def __init__(self, local, name=None):
- object.__setattr__(self, '_LocalProxy__local', local)
- object.__setattr__(self, '__name__', name)
- if callable(local) and not hasattr(local, '__release_local__'):
+ object.__setattr__(self, "_LocalProxy__local", local)
+ object.__setattr__(self, "__name__", name)
+ if callable(local) and not hasattr(local, "__release_local__"):
# "local" is a callable that is not an instance of Local or
# LocalManager: mark it as a wrapped function.
- object.__setattr__(self, '__wrapped__', local)
+ object.__setattr__(self, "__wrapped__", local)
def _get_current_object(self):
"""Return the current object. This is useful if you want the real
object behind the proxy at a time for performance reasons or because
you want to pass the object into a different context.
"""
- if not hasattr(self.__local, '__release_local__'):
+ if not hasattr(self.__local, "__release_local__"):
return self.__local()
try:
return getattr(self.__local, self.__name__)
except AttributeError:
- raise RuntimeError('no object bound to %s' % self.__name__)
+ raise RuntimeError("no object bound to %s" % self.__name__)
@property
def __dict__(self):
try:
return self._get_current_object().__dict__
except RuntimeError:
- raise AttributeError('__dict__')
+ raise AttributeError("__dict__")
def __repr__(self):
try:
obj = self._get_current_object()
except RuntimeError:
- return '<%s unbound>' % self.__class__.__name__
+ return "<%s unbound>" % self.__class__.__name__
return repr(obj)
def __bool__(self):
@@ -342,7 +343,7 @@ class LocalProxy(object):
return []
def __getattr__(self, name):
- if name == '__members__':
+ if name == "__members__":
return dir(self._get_current_object())
return getattr(self._get_current_object(), name)
diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py
index 0890cf00..bfdc0712 100644
--- a/src/werkzeug/middleware/http_proxy.py
+++ b/src/werkzeug/middleware/http_proxy.py
@@ -7,7 +7,6 @@ Basic HTTP Proxy
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
import socket
from ..datastructures import EnvironHeaders
@@ -117,7 +116,7 @@ class ProxyMiddleware(object):
if opts["remove_prefix"]:
remote_path = "%s/%s" % (
target.path.rstrip("/"),
- remote_path[len(prefix):].lstrip("/"),
+ remote_path[len(prefix) :].lstrip("/"),
)
content_length = environ.get("CONTENT_LENGTH")
@@ -179,7 +178,7 @@ class ProxyMiddleware(object):
resp = con.getresponse()
except socket.error:
- from werkzeug.exceptions import BadGateway
+ from ..exceptions import BadGateway
return BadGateway()(environ, start_response)
diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py
index 0c71f9ca..15372819 100644
--- a/src/werkzeug/middleware/lint.py
+++ b/src/werkzeug/middleware/lint.py
@@ -164,7 +164,7 @@ class GuardedIterator(object):
content_length = headers.get("content-length", type=int)
if status_code == 304:
- for key, value in headers:
+ for key, _value in headers:
key = key.lower()
if key not in ("expires", "content-location") and is_entity_header(
key
diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py
index 64041f61..8e2edc20 100644
--- a/src/werkzeug/middleware/profiler.py
+++ b/src/werkzeug/middleware/profiler.py
@@ -12,6 +12,7 @@ that may be slowing down your application.
:license: BSD-3-Clause
"""
from __future__ import print_function
+
import os.path
import sys
import time
diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py
index 6fcd2ef6..dc1dacc8 100644
--- a/src/werkzeug/middleware/proxy_fix.py
+++ b/src/werkzeug/middleware/proxy_fix.py
@@ -141,7 +141,7 @@ class ProxyFix(object):
"'get_remote_addr' is deprecated as of version 0.15 and"
" will be removed in version 1.0. It is now handled"
" internally for each header.",
- DeprecationWarning
+ DeprecationWarning,
)
return self._get_trusted_comma(self.x_for, ",".join(forwarded_for))
diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py
index 5ea3f87e..a902281d 100644
--- a/src/werkzeug/middleware/shared_data.py
+++ b/src/werkzeug/middleware/shared_data.py
@@ -8,7 +8,6 @@ Serve Shared Static Files
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
import mimetypes
import os
import posixpath
@@ -219,7 +218,7 @@ class SharedDataMiddleware(object):
search_path += "/"
if path.startswith(search_path):
- real_filename, file_loader = loader(path[len(search_path):])
+ real_filename, file_loader = loader(path[len(search_path) :])
if file_loader is not None:
break
diff --git a/src/werkzeug/posixemulation.py b/src/werkzeug/posixemulation.py
index dbf909de..696b4562 100644
--- a/src/werkzeug/posixemulation.py
+++ b/src/werkzeug/posixemulation.py
@@ -17,21 +17,18 @@ r"""
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
-import os
import errno
-import time
+import os
import random
+import sys
+import time
from ._compat import to_unicode
from .filesystem import get_filesystem_encoding
-
can_rename_open_file = False
-if os.name == 'nt': # pragma: no cover
- _rename = lambda src, dst: False
- _rename_atomic = lambda src, dst: False
+if os.name == "nt":
try:
import ctypes
@@ -47,8 +44,9 @@ if os.name == 'nt': # pragma: no cover
retry = 0
rv = False
while not rv and retry < 100:
- rv = _MoveFileEx(src, dst, _MOVEFILE_REPLACE_EXISTING
- | _MOVEFILE_WRITE_THROUGH)
+ rv = _MoveFileEx(
+ src, dst, _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH
+ )
if not rv:
time.sleep(0.001)
retry += 1
@@ -62,16 +60,21 @@ if os.name == 'nt': # pragma: no cover
can_rename_open_file = True
def _rename_atomic(src, dst):
- ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, 'Werkzeug rename')
+ ta = _CreateTransaction(None, 0, 0, 0, 0, 1000, "Werkzeug rename")
if ta == -1:
return False
try:
retry = 0
rv = False
while not rv and retry < 100:
- rv = _MoveFileTransacted(src, dst, None, None,
- _MOVEFILE_REPLACE_EXISTING
- | _MOVEFILE_WRITE_THROUGH, ta)
+ rv = _MoveFileTransacted(
+ src,
+ dst,
+ None,
+ None,
+ _MOVEFILE_REPLACE_EXISTING | _MOVEFILE_WRITE_THROUGH,
+ ta,
+ )
if rv:
rv = _CommitTransaction(ta)
break
@@ -81,8 +84,14 @@ if os.name == 'nt': # pragma: no cover
return rv
finally:
_CloseHandle(ta)
+
except Exception:
- pass
+
+ def _rename(src, dst):
+ return False
+
+ def _rename_atomic(src, dst):
+ return False
def rename(src, dst):
# Try atomic or pseudo-atomic rename
@@ -101,6 +110,8 @@ if os.name == 'nt': # pragma: no cover
os.unlink(old)
except Exception:
pass
+
+
else:
rename = os.rename
can_rename_open_file = True
diff --git a/src/werkzeug/routing.py b/src/werkzeug/routing.py
index 3e349814..875c0154 100644
--- a/src/werkzeug/routing.py
+++ b/src/werkzeug/routing.py
@@ -96,30 +96,44 @@
:license: BSD-3-Clause
"""
import difflib
-import re
-import uuid
-import posixpath
import dis
+import posixpath
+import re
import sys
import types
-
+import uuid
from functools import partial
from pprint import pformat
from threading import Lock
-from werkzeug.urls import url_encode, url_quote, url_join, _fast_url_quote
-from werkzeug.utils import redirect, format_string
-from werkzeug.exceptions import HTTPException, NotFound, MethodNotAllowed, \
- BadHost
-from werkzeug._internal import _get_environ, _encode_idna
-from werkzeug._compat import itervalues, iteritems, to_unicode, to_bytes, \
- text_type, string_types, native_string_result, \
- implements_to_string, wsgi_decoding_dance
-from werkzeug.datastructures import ImmutableDict, MultiDict
-from werkzeug.utils import cached_property
-from werkzeug.wsgi import get_host
-
-_rule_re = re.compile(r'''
+from ._compat import implements_to_string
+from ._compat import iteritems
+from ._compat import itervalues
+from ._compat import native_string_result
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import to_unicode
+from ._compat import wsgi_decoding_dance
+from ._internal import _encode_idna
+from ._internal import _get_environ
+from .datastructures import ImmutableDict
+from .datastructures import MultiDict
+from .exceptions import BadHost
+from .exceptions import HTTPException
+from .exceptions import MethodNotAllowed
+from .exceptions import NotFound
+from .urls import _fast_url_quote
+from .urls import url_encode
+from .urls import url_join
+from .urls import url_quote
+from .utils import cached_property
+from .utils import format_string
+from .utils import redirect
+from .wsgi import get_host
+
+_rule_re = re.compile(
+ r"""
(?P<static>[^<]*) # static rule data
<
(?:
@@ -129,9 +143,12 @@ _rule_re = re.compile(r'''
)?
(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*) # variable name
>
-''', re.VERBOSE)
-_simple_rule_re = re.compile(r'<([^>]+)>')
-_converter_args_re = re.compile(r'''
+ """,
+ re.VERBOSE,
+)
+_simple_rule_re = re.compile(r"<([^>]+)>")
+_converter_args_re = re.compile(
+ r"""
((?P<name>\w+)\s*=\s*)?
(?P<value>
True|False|
@@ -141,14 +158,12 @@ _converter_args_re = re.compile(r'''
[\w\d_.]+|
[urUR]?(?P<stringval>"[^"]*?"|'[^']*')
)\s*,
-''', re.VERBOSE | re.UNICODE)
+ """,
+ re.VERBOSE | re.UNICODE,
+)
-_PYTHON_CONSTANTS = {
- 'None': None,
- 'True': True,
- 'False': False
-}
+_PYTHON_CONSTANTS = {"None": None, "True": True, "False": False}
def _pythonize(value):
@@ -159,25 +174,25 @@ def _pythonize(value):
return convert(value)
except ValueError:
pass
- if value[:1] == value[-1:] and value[0] in '"\'':
+ if value[:1] == value[-1:] and value[0] in "\"'":
value = value[1:-1]
return text_type(value)
def parse_converter_args(argstr):
- argstr += ','
+ argstr += ","
args = []
kwargs = {}
for item in _converter_args_re.finditer(argstr):
- value = item.group('stringval')
+ value = item.group("stringval")
if value is None:
- value = item.group('value')
+ value = item.group("value")
value = _pythonize(value)
- if not item.group('name'):
+ if not item.group("name"):
args.append(value)
else:
- name = item.group('name')
+ name = item.group("name")
kwargs[name] = value
return tuple(args), kwargs
@@ -199,24 +214,23 @@ def parse_rule(rule):
if m is None:
break
data = m.groupdict()
- if data['static']:
- yield None, None, data['static']
- variable = data['variable']
- converter = data['converter'] or 'default'
+ if data["static"]:
+ yield None, None, data["static"]
+ variable = data["variable"]
+ converter = data["converter"] or "default"
if variable in used_names:
- raise ValueError('variable name %r used twice.' % variable)
+ raise ValueError("variable name %r used twice." % variable)
used_names.add(variable)
- yield converter, data['args'] or None, variable
+ yield converter, data["args"] or None, variable
pos = m.end()
if pos < end:
remaining = rule[pos:]
- if '>' in remaining or '<' in remaining:
- raise ValueError('malformed url rule: %r' % rule)
+ if ">" in remaining or "<" in remaining:
+ raise ValueError("malformed url rule: %r" % rule)
yield None, None, remaining
class RoutingException(Exception):
-
"""Special exceptions that require the application to redirect, notifying
about missing urls, etc.
@@ -225,12 +239,12 @@ class RoutingException(Exception):
class RequestRedirect(HTTPException, RoutingException):
-
"""Raise if the map requests a redirect. This is for example the case if
`strict_slashes` are activated and an url that requires a trailing slash.
The attribute `new_url` contains the absolute destination url.
"""
+
code = 308
def __init__(self, new_url):
@@ -242,12 +256,10 @@ class RequestRedirect(HTTPException, RoutingException):
class RequestSlash(RoutingException):
-
"""Internal exception."""
-class RequestAliasRedirect(RoutingException):
-
+class RequestAliasRedirect(RoutingException): # noqa: B903
"""This rule is an alias and wants to redirect to the canonical URL."""
def __init__(self, matched_values):
@@ -256,7 +268,6 @@ class RequestAliasRedirect(RoutingException):
@implements_to_string
class BuildError(RoutingException, LookupError):
-
"""Raised if the build system cannot find a URL for an endpoint with the
values provided.
"""
@@ -274,55 +285,54 @@ class BuildError(RoutingException, LookupError):
def closest_rule(self, adapter):
def _score_rule(rule):
- return sum([
- 0.98 * difflib.SequenceMatcher(
- None, rule.endpoint, self.endpoint
- ).ratio(),
- 0.01 * bool(set(self.values or ()).issubset(rule.arguments)),
- 0.01 * bool(rule.methods and self.method in rule.methods)
- ])
+ return sum(
+ [
+ 0.98
+ * difflib.SequenceMatcher(
+ None, rule.endpoint, self.endpoint
+ ).ratio(),
+ 0.01 * bool(set(self.values or ()).issubset(rule.arguments)),
+ 0.01 * bool(rule.methods and self.method in rule.methods),
+ ]
+ )
if adapter and adapter.map._rules:
return max(adapter.map._rules, key=_score_rule)
def __str__(self):
message = []
- message.append('Could not build url for endpoint %r' % self.endpoint)
+ message.append("Could not build url for endpoint %r" % self.endpoint)
if self.method:
- message.append(' (%r)' % self.method)
+ message.append(" (%r)" % self.method)
if self.values:
- message.append(' with values %r' % sorted(self.values.keys()))
- message.append('.')
+ message.append(" with values %r" % sorted(self.values.keys()))
+ message.append(".")
if self.suggested:
if self.endpoint == self.suggested.endpoint:
if self.method and self.method not in self.suggested.methods:
- message.append(' Did you mean to use methods %r?' % sorted(
- self.suggested.methods
- ))
+ message.append(
+ " Did you mean to use methods %r?"
+ % sorted(self.suggested.methods)
+ )
missing_values = self.suggested.arguments.union(
set(self.suggested.defaults or ())
) - set(self.values.keys())
if missing_values:
message.append(
- ' Did you forget to specify values %r?' %
- sorted(missing_values)
+ " Did you forget to specify values %r?" % sorted(missing_values)
)
else:
- message.append(
- ' Did you mean %r instead?' % self.suggested.endpoint
- )
- return u''.join(message)
+ message.append(" Did you mean %r instead?" % self.suggested.endpoint)
+ return u"".join(message)
class ValidationError(ValueError):
-
"""Validation error. If a rule converter raises this exception the rule
does not match the current URL and the next URL is tried.
"""
class RuleFactory(object):
-
"""As soon as you have more complex URL setups it's a good idea to use rule
factories to avoid repetitive tasks. Some of them are builtin, others can
be added by subclassing `RuleFactory` and overriding `get_rules`.
@@ -335,7 +345,6 @@ class RuleFactory(object):
class Subdomain(RuleFactory):
-
"""All URLs provided by this factory have the subdomain set to a
specific domain. For example if you want to use the subdomain for
the current language this can be a good setup::
@@ -367,7 +376,6 @@ class Subdomain(RuleFactory):
class Submount(RuleFactory):
-
"""Like `Subdomain` but prefixes the URL rule with a given string::
url_map = Map([
@@ -382,7 +390,7 @@ class Submount(RuleFactory):
"""
def __init__(self, path, rules):
- self.path = path.rstrip('/')
+ self.path = path.rstrip("/")
self.rules = rules
def get_rules(self, map):
@@ -394,7 +402,6 @@ class Submount(RuleFactory):
class EndpointPrefix(RuleFactory):
-
"""Prefixes all endpoints (which must be strings for this factory) with
another string. This can be useful for sub applications::
@@ -420,7 +427,6 @@ class EndpointPrefix(RuleFactory):
class RuleTemplate(object):
-
"""Returns copies of the rules wrapped and expands string templates in
the endpoint, rule, defaults or subdomain sections.
@@ -447,7 +453,6 @@ class RuleTemplate(object):
class RuleTemplateFactory(RuleFactory):
-
"""A factory that fills in template variables into rules. Used by
`RuleTemplate` internally.
@@ -480,13 +485,12 @@ class RuleTemplateFactory(RuleFactory):
rule.methods,
rule.build_only,
new_endpoint,
- rule.strict_slashes
+ rule.strict_slashes,
)
@implements_to_string
class Rule(RuleFactory):
-
"""A Rule represents one URL pattern. There are some options for `Rule`
that change the way it behaves and are passed to the `Rule` constructor.
Note that besides the rule-string all arguments *must* be keyword arguments
@@ -600,13 +604,23 @@ class Rule(RuleFactory):
The `alias` and `host` parameters were added.
"""
- def __init__(self, string, defaults=None, subdomain=None, methods=None,
- build_only=False, endpoint=None, strict_slashes=None,
- redirect_to=None, alias=False, host=None):
- if not string.startswith('/'):
- raise ValueError('urls must start with a leading slash')
+ def __init__(
+ self,
+ string,
+ defaults=None,
+ subdomain=None,
+ methods=None,
+ build_only=False,
+ endpoint=None,
+ strict_slashes=None,
+ redirect_to=None,
+ alias=False,
+ host=None,
+ ):
+ if not string.startswith("/"):
+ raise ValueError("urls must start with a leading slash")
self.rule = string
- self.is_leaf = not string.endswith('/')
+ self.is_leaf = not string.endswith("/")
self.map = None
self.strict_slashes = strict_slashes
@@ -619,10 +633,10 @@ class Rule(RuleFactory):
self.methods = None
else:
if isinstance(methods, str):
- raise TypeError('param `methods` should be `Iterable[str]`, not `str`')
+ raise TypeError("param `methods` should be `Iterable[str]`, not `str`")
self.methods = set([x.upper() for x in methods])
- if 'HEAD' not in self.methods and 'GET' in self.methods:
- self.methods.add('HEAD')
+ if "HEAD" not in self.methods and "GET" in self.methods:
+ self.methods.add("HEAD")
self.endpoint = endpoint
self.redirect_to = redirect_to
@@ -657,11 +671,17 @@ class Rule(RuleFactory):
defaults = None
if self.defaults:
defaults = dict(self.defaults)
- return dict(defaults=defaults, subdomain=self.subdomain,
- methods=self.methods, build_only=self.build_only,
- endpoint=self.endpoint, strict_slashes=self.strict_slashes,
- redirect_to=self.redirect_to, alias=self.alias,
- host=self.host)
+ return dict(
+ defaults=defaults,
+ subdomain=self.subdomain,
+ methods=self.methods,
+ build_only=self.build_only,
+ endpoint=self.endpoint,
+ strict_slashes=self.strict_slashes,
+ redirect_to=self.redirect_to,
+ alias=self.alias,
+ host=self.host,
+ )
def get_rules(self, map):
yield self
@@ -681,8 +701,7 @@ class Rule(RuleFactory):
:internal:
"""
if self.map is not None and not rebind:
- raise RuntimeError('url rule %r already bound to map %r' %
- (self, self.map))
+ raise RuntimeError("url rule %r already bound to map %r" % (self, self.map))
self.map = map
if self.strict_slashes is None:
self.strict_slashes = map.strict_slashes
@@ -696,17 +715,17 @@ class Rule(RuleFactory):
.. versionadded:: 0.9
"""
if converter_name not in self.map.converters:
- raise LookupError('the converter %r does not exist' % converter_name)
+ raise LookupError("the converter %r does not exist" % converter_name)
return self.map.converters[converter_name](self.map, *args, **kwargs)
def compile(self):
"""Compiles the regular expression and stores it."""
- assert self.map is not None, 'rule not bound'
+ assert self.map is not None, "rule not bound"
if self.map.host_matching:
- domain_rule = self.host or ''
+ domain_rule = self.host or ""
else:
- domain_rule = self.subdomain or ''
+ domain_rule = self.subdomain or ""
self._trace = []
self._converters = {}
@@ -720,7 +739,7 @@ class Rule(RuleFactory):
if converter is None:
regex_parts.append(re.escape(variable))
self._trace.append((False, variable))
- for part in variable.split('/'):
+ for part in variable.split("/"):
if part:
self._static_weights.append((index, -len(part)))
else:
@@ -729,9 +748,8 @@ class Rule(RuleFactory):
else:
c_args = ()
c_kwargs = {}
- convobj = self.get_converter(
- variable, converter, c_args, c_kwargs)
- regex_parts.append('(?P<%s>%s)' % (variable, convobj.regex))
+ convobj = self.get_converter(variable, converter, c_args, c_kwargs)
+ regex_parts.append("(?P<%s>%s)" % (variable, convobj.regex))
self._converters[variable] = convobj
self._trace.append((True, variable))
self._argument_weights.append(convobj.weight)
@@ -739,21 +757,22 @@ class Rule(RuleFactory):
index = index + 1
_build_regex(domain_rule)
- regex_parts.append('\\|')
- self._trace.append((False, '|'))
- _build_regex(self.rule if self.is_leaf else self.rule.rstrip('/'))
+ regex_parts.append("\\|")
+ self._trace.append((False, "|"))
+ _build_regex(self.rule if self.is_leaf else self.rule.rstrip("/"))
if not self.is_leaf:
- self._trace.append((False, '/'))
+ self._trace.append((False, "/"))
self._build = self._compile_builder(False)
self._build_unknown = self._compile_builder(True)
if self.build_only:
return
- regex = r'^%s%s$' % (
- u''.join(regex_parts),
+ regex = r"^%s%s$" % (
+ u"".join(regex_parts),
(not self.is_leaf or not self.strict_slashes)
- and '(?<!/)(?P<__suffix__>/?)' or ''
+ and "(?<!/)(?P<__suffix__>/?)"
+ or "",
)
self._regex = re.compile(regex, re.UNICODE)
@@ -776,15 +795,19 @@ class Rule(RuleFactory):
# slash and strict slashes enabled. raise an exception that
# tells the map to redirect to the same url but with a
# trailing slash
- if self.strict_slashes and not self.is_leaf and \
- not groups.pop('__suffix__') and \
- (method is None or self.methods is None
- or method in self.methods):
+ if (
+ self.strict_slashes
+ and not self.is_leaf
+ and not groups.pop("__suffix__")
+ and (
+ method is None or self.methods is None or method in self.methods
+ )
+ ):
raise RequestSlash()
# if we are not in strict slashes mode we have to remove
# a __suffix__
elif not self.strict_slashes:
- del groups['__suffix__']
+ del groups["__suffix__"]
result = {}
for name, value in iteritems(groups):
@@ -802,7 +825,7 @@ class Rule(RuleFactory):
return result
class BuilderCompiler:
- JOIN_EMPTY = ''.join
+ JOIN_EMPTY = "".join
if sys.version_info >= (3, 6):
OPARG_SIZE = 256
OPARG_VARI = False
@@ -876,14 +899,14 @@ class Rule(RuleFactory):
if op is not None:
new.append((op, elem))
continue
- if elem == '':
+ if elem == "":
continue
if not new or new[-1][0] is not None:
new.append((op, elem))
continue
new[-1] = (None, new[-1][1] + elem)
if not new:
- new.append((None, ''))
+ new.append((None, ""))
return new
def build_op(self, op, arg=None):
@@ -891,20 +914,17 @@ class Rule(RuleFactory):
if isinstance(op, str):
op = dis.opmap[op]
if arg is None and op >= dis.HAVE_ARGUMENT:
- raise ValueError(
- "Operation requires an argument: %s" % dis.opname[op])
+ raise ValueError("Operation requires an argument: %s" % dis.opname[op])
if arg is not None and op < dis.HAVE_ARGUMENT:
- raise ValueError(
- "Operation takes no argument: %s" % dis.opname[op])
+ raise ValueError("Operation takes no argument: %s" % dis.opname[op])
if arg is None:
arg = 0
# Python 3.6 changed the argument to an 8-bit integer, so this
# could be a practical consideration
if arg >= self.OPARG_SIZE:
- return (
- self.build_op('EXTENDED_ARG', arg // self.OPARG_SIZE)
- + self.build_op(op, arg % self.OPARG_SIZE)
- )
+ return self.build_op(
+ "EXTENDED_ARG", arg // self.OPARG_SIZE
+ ) + self.build_op(op, arg % self.OPARG_SIZE)
if not self.OPARG_VARI:
return bytearray((op, arg))
elif op >= dis.HAVE_ARGUMENT:
@@ -918,58 +938,57 @@ class Rule(RuleFactory):
already be immediately below the string elements on the
stack.
"""
- if 'BUILD_STRING' in dis.opmap:
- return self.build_op('BUILD_STRING', n)
+ if "BUILD_STRING" in dis.opmap:
+ return self.build_op("BUILD_STRING", n)
else:
- return (
- self.build_op('BUILD_TUPLE', n)
- + self.build_op('CALL_FUNCTION', 1)
+ return self.build_op("BUILD_TUPLE", n) + self.build_op(
+ "CALL_FUNCTION", 1
)
def emit_build(
- self, ind, opl, append_unknown=False, encode_query_vars=None,
- kwargs=None
+ self, ind, opl, append_unknown=False, encode_query_vars=None, kwargs=None
):
- ops = b''
+ ops = b""
n = len(opl)
stack = 0
stack_overhead = 0
for op, elem in opl:
if op is None:
- ops += self.build_op('LOAD_CONST', self.get_const(elem))
+ ops += self.build_op("LOAD_CONST", self.get_const(elem))
stack_overhead = 0
continue
- ops += self.build_op('LOAD_CONST', self.get_const(op))
- ops += self.build_op('LOAD_FAST', self.get_var(elem))
- ops += self.build_op('CALL_FUNCTION', 1)
+ ops += self.build_op("LOAD_CONST", self.get_const(op))
+ ops += self.build_op("LOAD_FAST", self.get_var(elem))
+ ops += self.build_op("CALL_FUNCTION", 1)
stack_overhead = 2
stack += len(opl)
peak_stack = stack + stack_overhead
dont_build_string = False
- needs_build_string = 'BUILD_STRING' not in dis.opmap
+ needs_build_string = "BUILD_STRING" not in dis.opmap
if n <= 1:
dont_build_string = True
needs_build_string = False
if append_unknown:
- if 'BUILD_STRING' not in dis.opmap:
+ if "BUILD_STRING" not in dis.opmap:
needs_build_string = True
- ops = self.build_op(
- 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops
- ops += self.build_op('LOAD_FAST', kwargs)
+ ops = (
+ self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY))
+ + ops
+ )
+ ops += self.build_op("LOAD_FAST", kwargs)
# assemble this in its own buffers because we need to
# jump over it
uops = bytearray() # run if kwargs. TOS=kwargs
- uops += self.build_op(
- 'LOAD_CONST', self.get_const(encode_query_vars))
- uops += self.build_op('ROT_TWO')
- uops += self.build_op('CALL_FUNCTION', 1)
- uops += self.build_op('LOAD_CONST', self.get_const('?'))
- uops += self.build_op('ROT_TWO')
+ uops += self.build_op("LOAD_CONST", self.get_const(encode_query_vars))
+ uops += self.build_op("ROT_TWO")
+ uops += self.build_op("CALL_FUNCTION", 1)
+ uops += self.build_op("LOAD_CONST", self.get_const("?"))
+ uops += self.build_op("ROT_TWO")
if dont_build_string:
uops += self.build_string(n + 2)
@@ -977,23 +996,23 @@ class Rule(RuleFactory):
if not dont_build_string:
# if we're going to build a string, we need to pad out to
# a constant length
- nops += self.build_op('LOAD_CONST', self.get_const(''))
- nops += self.build_op('DUP_TOP')
+ nops += self.build_op("LOAD_CONST", self.get_const(""))
+ nops += self.build_op("DUP_TOP")
elif needs_build_string:
# we inserted the ''.join reference at the bottom of the
# stack, but we don't want to call it: throw it away
- nops += self.build_op('ROT_TWO')
- nops += self.build_op('POP_TOP')
- nops += self.build_op('JUMP_FORWARD', len(uops))
+ nops += self.build_op("ROT_TWO")
+ nops += self.build_op("POP_TOP")
+ nops += self.build_op("JUMP_FORWARD", len(uops))
# this jump needs to take its own length into account. the
# simple way to do that is to compute a minimal guess for the
# length of the jump instruction, and keep revising it upward
- jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', 0)
+ jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", 0)
while True:
jump_len = len(jump_op)
jump_target = ind + len(ops) + jump_len + len(nops)
- jump_op = self.build_op('JUMP_IF_TRUE_OR_POP', jump_target)
+ jump_op = self.build_op("JUMP_IF_TRUE_OR_POP", jump_target)
assert len(jump_op) >= jump_len
if len(jump_op) == jump_len:
break
@@ -1005,8 +1024,7 @@ class Rule(RuleFactory):
n += 2
peak_stack = max(peak_stack, stack + 2)
elif needs_build_string:
- ops = self.build_op(
- 'LOAD_CONST', self.get_const(self.JOIN_EMPTY)) + ops
+ ops = self.build_op("LOAD_CONST", self.get_const(self.JOIN_EMPTY)) + ops
peak_stack += 1
if not dont_build_string:
ops += self.build_string(n)
@@ -1022,35 +1040,41 @@ class Rule(RuleFactory):
url_encode,
charset=self.rule.map.charset,
sort=self.rule.map.sort_parameters,
- key=self.rule.map.sort_key)
+ key=self.rule.map.sort_key,
+ )
for is_dynamic, data in self.rule._trace:
- if data == '|' and opl is dom_ops:
+ if data == "|" and opl is dom_ops:
opl = url_ops
continue
# this seems like a silly case to ever come up but:
# if a default is given for a value that appears in the rule,
# resolve it to a constant ahead of time
if is_dynamic and data in self.defaults:
- data = self.rule._converters[data].to_url(
- self.defaults[data])
+ data = self.rule._converters[data].to_url(self.defaults[data])
is_dynamic = False
if not is_dynamic:
- opl.append((None, url_quote(
- to_bytes(data, self.rule.map.charset), safe='/:|+')))
+ opl.append(
+ (
+ None,
+ url_quote(
+ to_bytes(data, self.rule.map.charset), safe="/:|+"
+ ),
+ )
+ )
continue
opl.append((self.rule._converters[data].to_url, data))
dom_ops = self.collapse_constants(dom_ops)
url_ops = self.collapse_constants(url_ops)
- for op, elem in (dom_ops + url_ops):
+ for op, elem in dom_ops + url_ops:
if op is not None:
self.get_var(elem)
self.add_defaults()
argcount = len(self.var)
# invalid name for paranoia reasons
- self.get_var('.keyword_arguments')
+ self.get_var(".keyword_arguments")
stack = 0
peak_stack = 0
- ops = b''
+ ops = b""
if (
not append_unknown
and len(dom_ops) == len(url_ops) == 1
@@ -1059,8 +1083,7 @@ class Rule(RuleFactory):
# shortcut: just return the constant
stack = peak_stack = 1
constant_value = (dom_ops[0][1], url_ops[0][1])
- ops += self.build_op(
- 'LOAD_CONST', self.get_const(constant_value))
+ ops += self.build_op("LOAD_CONST", self.get_const(constant_value))
else:
ps, rv = self.emit_build(len(ops), dom_ops)
ops += rv
@@ -1068,14 +1091,14 @@ class Rule(RuleFactory):
stack += 1
if append_unknown:
ps, rv = self.emit_build(
- len(ops), url_ops, append_unknown, encode_query_vars,
- argcount)
+ len(ops), url_ops, append_unknown, encode_query_vars, argcount
+ )
else:
ps, rv = self.emit_build(len(ops), url_ops)
ops += rv
peak_stack = max(stack + ps, peak_stack)
- ops += self.build_op('BUILD_TUPLE', 2)
- ops += self.build_op('RETURN_VALUE')
+ ops += self.build_op("BUILD_TUPLE", 2)
+ ops += self.build_op("RETURN_VALUE")
code_args = [
argcount,
len(self.var),
@@ -1085,10 +1108,10 @@ class Rule(RuleFactory):
tuple(self.consts),
(),
tuple(self.var),
- 'generated',
- '<builder:%r>' % self.rule.rule,
+ "generated",
+ "<builder:%r>" % self.rule.rule,
1,
- b''
+ b"",
]
if sys.version_info >= (3,):
code_args[1:1] = [0]
@@ -1124,9 +1147,13 @@ class Rule(RuleFactory):
:internal:
"""
- return not self.build_only and self.defaults and \
- self.endpoint == rule.endpoint and self != rule and \
- self.arguments == rule.arguments
+ return (
+ not self.build_only
+ and self.defaults
+ and self.endpoint == rule.endpoint
+ and self != rule
+ and self.arguments == rule.arguments
+ )
def suitable_for(self, values, method=None):
"""Check if the dict of values has enough data for url generation.
@@ -1135,8 +1162,11 @@ class Rule(RuleFactory):
"""
# if a method was given explicitly and that method is not supported
# by this rule, this rule is not suitable.
- if method is not None and self.methods is not None \
- and method not in self.methods:
+ if (
+ method is not None
+ and self.methods is not None
+ and method not in self.methods
+ ):
return False
defaults = self.defaults or ()
@@ -1174,20 +1204,23 @@ class Rule(RuleFactory):
:internal:
"""
- return bool(self.arguments), -len(self._static_weights), self._static_weights,\
- -len(self._argument_weights), self._argument_weights
+ return (
+ bool(self.arguments),
+ -len(self._static_weights),
+ self._static_weights,
+ -len(self._argument_weights),
+ self._argument_weights,
+ )
def build_compare_key(self):
"""The build compare key for sorting.
:internal:
"""
- return 1 if self.alias else 0, -len(self.arguments), \
- -len(self.defaults or ())
+ return 1 if self.alias else 0, -len(self.arguments), -len(self.defaults or ())
def __eq__(self, other):
- return self.__class__ is other.__class__ and \
- self._trace == other._trace
+ return self.__class__ is other.__class__ and self._trace == other._trace
__hash__ = None
@@ -1200,27 +1233,25 @@ class Rule(RuleFactory):
@native_string_result
def __repr__(self):
if self.map is None:
- return u'<%s (unbound)>' % self.__class__.__name__
+ return u"<%s (unbound)>" % self.__class__.__name__
tmp = []
for is_dynamic, data in self._trace:
if is_dynamic:
- tmp.append(u'<%s>' % data)
+ tmp.append(u"<%s>" % data)
else:
tmp.append(data)
- return u'<%s %s%s -> %s>' % (
+ return u"<%s %s%s -> %s>" % (
self.__class__.__name__,
- repr((u''.join(tmp)).lstrip(u'|')).lstrip(u'u'),
- self.methods is not None
- and u' (%s)' % u', '.join(self.methods)
- or u'',
- self.endpoint
+ repr((u"".join(tmp)).lstrip(u"|")).lstrip(u"u"),
+ self.methods is not None and u" (%s)" % u", ".join(self.methods) or u"",
+ self.endpoint,
)
class BaseConverter(object):
-
"""Base class for all converters."""
- regex = '[^/]+'
+
+ regex = "[^/]+"
weight = 100
def __init__(self, map):
@@ -1234,7 +1265,6 @@ class BaseConverter(object):
class UnicodeConverter(BaseConverter):
-
"""This converter is the default converter and accepts any string but
only one path segment. Thus the string can not include a slash.
@@ -1255,21 +1285,17 @@ class UnicodeConverter(BaseConverter):
def __init__(self, map, minlength=1, maxlength=None, length=None):
BaseConverter.__init__(self, map)
if length is not None:
- length = '{%d}' % int(length)
+ length = "{%d}" % int(length)
else:
if maxlength is None:
- maxlength = ''
+ maxlength = ""
else:
maxlength = int(maxlength)
- length = '{%s,%s}' % (
- int(minlength),
- maxlength
- )
- self.regex = '[^/]' + length
+ length = "{%s,%s}" % (int(minlength), maxlength)
+ self.regex = "[^/]" + length
class AnyConverter(BaseConverter):
-
"""Matches one of the items provided. Items can either be Python
identifiers or strings::
@@ -1282,11 +1308,10 @@ class AnyConverter(BaseConverter):
def __init__(self, map, *items):
BaseConverter.__init__(self, map)
- self.regex = '(?:%s)' % '|'.join([re.escape(x) for x in items])
+ self.regex = "(?:%s)" % "|".join([re.escape(x) for x in items])
class PathConverter(BaseConverter):
-
"""Like the default :class:`UnicodeConverter`, but it also matches
slashes. This is useful for wikis and similar applications::
@@ -1295,16 +1320,17 @@ class PathConverter(BaseConverter):
:param map: the :class:`Map`.
"""
- regex = '[^/].*?'
+
+ regex = "[^/].*?"
weight = 200
class NumberConverter(BaseConverter):
-
"""Baseclass for `IntegerConverter` and `FloatConverter`.
:internal:
"""
+
weight = 50
def __init__(self, map, fixed_digits=0, min=None, max=None, signed=False):
@@ -1317,18 +1343,19 @@ class NumberConverter(BaseConverter):
self.signed = signed
def to_python(self, value):
- if (self.fixed_digits and len(value) != self.fixed_digits):
+ if self.fixed_digits and len(value) != self.fixed_digits:
raise ValidationError()
value = self.num_convert(value)
- if (self.min is not None and value < self.min) or \
- (self.max is not None and value > self.max):
+ if (self.min is not None and value < self.min) or (
+ self.max is not None and value > self.max
+ ):
raise ValidationError()
return value
def to_url(self, value):
value = self.num_convert(value)
if self.fixed_digits:
- value = ('%%0%sd' % self.fixed_digits) % value
+ value = ("%%0%sd" % self.fixed_digits) % value
return str(value)
@property
@@ -1337,7 +1364,6 @@ class NumberConverter(BaseConverter):
class IntegerConverter(NumberConverter):
-
"""This converter only accepts integer values::
Rule("/page/<int:page>")
@@ -1358,12 +1384,12 @@ class IntegerConverter(NumberConverter):
.. versionadded:: 0.15
The ``signed`` parameter.
"""
- regex = r'\d+'
+
+ regex = r"\d+"
num_convert = int
class FloatConverter(NumberConverter):
-
"""This converter only accepts floating point values::
Rule("/probability/<float:probability>")
@@ -1381,7 +1407,8 @@ class FloatConverter(NumberConverter):
.. versionadded:: 0.15
The ``signed`` parameter.
"""
- regex = r'\d+\.\d+'
+
+ regex = r"\d+\.\d+"
num_convert = float
def __init__(self, map, min=None, max=None, signed=False):
@@ -1389,7 +1416,6 @@ class FloatConverter(NumberConverter):
class UUIDConverter(BaseConverter):
-
"""This converter only accepts UUID strings::
Rule('/object/<uuid:identifier>')
@@ -1398,8 +1424,11 @@ class UUIDConverter(BaseConverter):
:param map: the :class:`Map`.
"""
- regex = r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' \
- r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}'
+
+ regex = (
+ r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-"
+ r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}"
+ )
def to_python(self, value):
return uuid.UUID(value)
@@ -1410,18 +1439,17 @@ class UUIDConverter(BaseConverter):
#: the default converter mapping for the map.
DEFAULT_CONVERTERS = {
- 'default': UnicodeConverter,
- 'string': UnicodeConverter,
- 'any': AnyConverter,
- 'path': PathConverter,
- 'int': IntegerConverter,
- 'float': FloatConverter,
- 'uuid': UUIDConverter,
+ "default": UnicodeConverter,
+ "string": UnicodeConverter,
+ "any": AnyConverter,
+ "path": PathConverter,
+ "int": IntegerConverter,
+ "float": FloatConverter,
+ "uuid": UUIDConverter,
}
class Map(object):
-
"""The map class stores all the URL rules and some configuration
parameters. Some of the configuration values are only stored on the
`Map` instance since those affect all rules, others are just defaults
@@ -1458,10 +1486,19 @@ class Map(object):
#: A dict of default converters to be used.
default_converters = ImmutableDict(DEFAULT_CONVERTERS)
- def __init__(self, rules=None, default_subdomain='', charset='utf-8',
- strict_slashes=True, redirect_defaults=True,
- converters=None, sort_parameters=False, sort_key=None,
- encoding_errors='replace', host_matching=False):
+ def __init__(
+ self,
+ rules=None,
+ default_subdomain="",
+ charset="utf-8",
+ strict_slashes=True,
+ redirect_defaults=True,
+ converters=None,
+ sort_parameters=False,
+ sort_key=None,
+ encoding_errors="replace",
+ host_matching=False,
+ ):
self._rules = []
self._rules_by_endpoint = {}
self._remap = True
@@ -1528,9 +1565,16 @@ class Map(object):
self._rules_by_endpoint.setdefault(rule.endpoint, []).append(rule)
self._remap = True
- def bind(self, server_name, script_name=None, subdomain=None,
- url_scheme='http', default_method='GET', path_info=None,
- query_args=None):
+ def bind(
+ self,
+ server_name,
+ script_name=None,
+ subdomain=None,
+ url_scheme="http",
+ default_method="GET",
+ path_info=None,
+ query_args=None,
+ ):
"""Return a new :class:`MapAdapter` with the details specified to the
call. Note that `script_name` will default to ``'/'`` if not further
specified or `None`. The `server_name` at least is a requirement
@@ -1559,20 +1603,27 @@ class Map(object):
server_name = server_name.lower()
if self.host_matching:
if subdomain is not None:
- raise RuntimeError('host matching enabled and a '
- 'subdomain was provided')
+ raise RuntimeError("host matching enabled and a subdomain was provided")
elif subdomain is None:
subdomain = self.default_subdomain
if script_name is None:
- script_name = '/'
+ script_name = "/"
if path_info is None:
- path_info = '/'
+ path_info = "/"
try:
server_name = _encode_idna(server_name)
except UnicodeError:
raise BadHost()
- return MapAdapter(self, server_name, script_name, subdomain,
- url_scheme, path_info, default_method, query_args)
+ return MapAdapter(
+ self,
+ server_name,
+ script_name,
+ subdomain,
+ url_scheme,
+ path_info,
+ default_method,
+ query_args,
+ )
def bind_to_environ(self, environ, server_name=None, subdomain=None):
"""Like :meth:`bind` but you can pass it an WSGI environment and it
@@ -1618,8 +1669,8 @@ class Map(object):
server_name = server_name.lower()
if subdomain is None and not self.host_matching:
- cur_server_name = wsgi_server_name.split('.')
- real_server_name = server_name.split('.')
+ cur_server_name = wsgi_server_name.split(".")
+ real_server_name = server_name.split(".")
offset = -len(real_server_name)
if cur_server_name[offset:] != real_server_name:
# This can happen even with valid configs if the server was
@@ -1627,22 +1678,28 @@ class Map(object):
# Instead of raising an exception like in Werkzeug 0.7 or
# earlier we go by an invalid subdomain which will result
# in a 404 error on matching.
- subdomain = '<invalid>'
+ subdomain = "<invalid>"
else:
- subdomain = '.'.join(filter(None, cur_server_name[:offset]))
+ subdomain = ".".join(filter(None, cur_server_name[:offset]))
def _get_wsgi_string(name):
val = environ.get(name)
if val is not None:
return wsgi_decoding_dance(val, self.charset)
- script_name = _get_wsgi_string('SCRIPT_NAME')
- path_info = _get_wsgi_string('PATH_INFO')
- query_args = _get_wsgi_string('QUERY_STRING')
- return Map.bind(self, server_name, script_name,
- subdomain, environ['wsgi.url_scheme'],
- environ['REQUEST_METHOD'], path_info,
- query_args=query_args)
+ script_name = _get_wsgi_string("SCRIPT_NAME")
+ path_info = _get_wsgi_string("PATH_INFO")
+ query_args = _get_wsgi_string("QUERY_STRING")
+ return Map.bind(
+ self,
+ server_name,
+ script_name,
+ subdomain,
+ environ["wsgi.url_scheme"],
+ environ["REQUEST_METHOD"],
+ path_info,
+ query_args=query_args,
+ )
def update(self):
"""Called before matching and building to keep the compiled rules
@@ -1662,7 +1719,7 @@ class Map(object):
def __repr__(self):
rules = self.iter_rules()
- return '%s(%s)' % (self.__class__.__name__, pformat(list(rules)))
+ return "%s(%s)" % (self.__class__.__name__, pformat(list(rules)))
class MapAdapter(object):
@@ -1671,13 +1728,22 @@ class MapAdapter(object):
the URL matching and building based on runtime information.
"""
- def __init__(self, map, server_name, script_name, subdomain,
- url_scheme, path_info, default_method, query_args=None):
+ def __init__(
+ self,
+ map,
+ server_name,
+ script_name,
+ subdomain,
+ url_scheme,
+ path_info,
+ default_method,
+ query_args=None,
+ ):
self.map = map
self.server_name = to_unicode(server_name)
script_name = to_unicode(script_name)
- if not script_name.endswith(u'/'):
- script_name += u'/'
+ if not script_name.endswith(u"/"):
+ script_name += u"/"
self.script_name = script_name
self.subdomain = to_unicode(subdomain)
self.url_scheme = to_unicode(url_scheme)
@@ -1685,8 +1751,9 @@ class MapAdapter(object):
self.default_method = to_unicode(default_method)
self.query_args = query_args
- def dispatch(self, view_func, path_info=None, method=None,
- catch_http_exceptions=False):
+ def dispatch(
+ self, view_func, path_info=None, method=None, catch_http_exceptions=False
+ ):
"""Does the complete dispatching process. `view_func` is called with
the endpoint and a dict with the values for the view. It should
look up the view function, call it, and return a response object
@@ -1740,8 +1807,7 @@ class MapAdapter(object):
return e
raise
- def match(self, path_info=None, method=None, return_rule=False,
- query_args=None):
+ def match(self, path_info=None, method=None, return_rule=False, query_args=None):
"""The usage is simple: you just pass the match method the current
path info as well as the method (which defaults to `GET`). The
following things can then happen:
@@ -1827,9 +1893,9 @@ class MapAdapter(object):
query_args = self.query_args
method = (method or self.default_method).upper()
- path = u'%s|%s' % (
+ path = u"%s|%s" % (
self.map.host_matching and self.server_name or self.subdomain,
- path_info and '/%s' % path_info.lstrip('/')
+ path_info and "/%s" % path_info.lstrip("/"),
)
have_match_for = set()
@@ -1837,12 +1903,18 @@ class MapAdapter(object):
try:
rv = rule.match(path, method)
except RequestSlash:
- raise RequestRedirect(self.make_redirect_url(
- url_quote(path_info, self.map.charset,
- safe='/:|+') + '/', query_args))
+ raise RequestRedirect(
+ self.make_redirect_url(
+ url_quote(path_info, self.map.charset, safe="/:|+") + "/",
+ query_args,
+ )
+ )
except RequestAliasRedirect as e:
- raise RequestRedirect(self.make_alias_redirect_url(
- path, rule.endpoint, e.matched_values, method, query_args))
+ raise RequestRedirect(
+ self.make_alias_redirect_url(
+ path, rule.endpoint, e.matched_values, method, query_args
+ )
+ )
if rv is None:
continue
if rule.methods is not None and method not in rule.methods:
@@ -1850,26 +1922,34 @@ class MapAdapter(object):
continue
if self.map.redirect_defaults:
- redirect_url = self.get_default_redirect(rule, method, rv,
- query_args)
+ redirect_url = self.get_default_redirect(rule, method, rv, query_args)
if redirect_url is not None:
raise RequestRedirect(redirect_url)
if rule.redirect_to is not None:
if isinstance(rule.redirect_to, string_types):
+
def _handle_match(match):
value = rv[match.group(1)]
return rule._converters[match.group(1)].to_url(value)
- redirect_url = _simple_rule_re.sub(_handle_match,
- rule.redirect_to)
+
+ redirect_url = _simple_rule_re.sub(_handle_match, rule.redirect_to)
else:
redirect_url = rule.redirect_to(self, **rv)
- raise RequestRedirect(str(url_join('%s://%s%s%s' % (
- self.url_scheme or 'http',
- self.subdomain + '.' if self.subdomain else '',
- self.server_name,
- self.script_name
- ), redirect_url)))
+ raise RequestRedirect(
+ str(
+ url_join(
+ "%s://%s%s%s"
+ % (
+ self.url_scheme or "http",
+ self.subdomain + "." if self.subdomain else "",
+ self.server_name,
+ self.script_name,
+ ),
+ redirect_url,
+ )
+ )
+ )
if return_rule:
return rule, rv
@@ -1903,7 +1983,7 @@ class MapAdapter(object):
.. versionadded:: 0.7
"""
try:
- self.match(path_info, method='--')
+ self.match(path_info, method="--")
except MethodNotAllowed as e:
return e.valid_methods
except HTTPException:
@@ -1918,13 +1998,13 @@ class MapAdapter(object):
if self.map.host_matching:
if domain_part is None:
return self.server_name
- return to_unicode(domain_part, 'ascii')
+ return to_unicode(domain_part, "ascii")
subdomain = domain_part
if subdomain is None:
subdomain = self.subdomain
else:
- subdomain = to_unicode(subdomain, 'ascii')
- return (subdomain + u'.' if subdomain else u'') + self.server_name
+ subdomain = to_unicode(subdomain, "ascii")
+ return (subdomain + u"." if subdomain else u"") + self.server_name
def get_default_redirect(self, rule, method, values, query_args):
"""A helper that returns the URL to redirect to if it finds one.
@@ -1939,12 +2019,10 @@ class MapAdapter(object):
# with the highest priority up for building.
if r is rule:
break
- if r.provides_defaults_for(rule) and \
- r.suitable_for(values, method):
+ if r.provides_defaults_for(rule) and r.suitable_for(values, method):
values.update(r.defaults)
domain_part, path = r.build(values)
- return self.make_redirect_url(
- path, query_args, domain_part=domain_part)
+ return self.make_redirect_url(path, query_args, domain_part=domain_part)
def encode_query_args(self, query_args):
if not isinstance(query_args, string_types):
@@ -1956,25 +2034,29 @@ class MapAdapter(object):
:internal:
"""
- suffix = ''
+ suffix = ""
if query_args:
- suffix = '?' + self.encode_query_args(query_args)
- return str('%s://%s/%s%s' % (
- self.url_scheme or 'http',
- self.get_host(domain_part),
- posixpath.join(self.script_name[:-1].lstrip('/'),
- path_info.lstrip('/')),
- suffix
- ))
+ suffix = "?" + self.encode_query_args(query_args)
+ return str(
+ "%s://%s/%s%s"
+ % (
+ self.url_scheme or "http",
+ self.get_host(domain_part),
+ posixpath.join(
+ self.script_name[:-1].lstrip("/"), path_info.lstrip("/")
+ ),
+ suffix,
+ )
+ )
def make_alias_redirect_url(self, path, endpoint, values, method, query_args):
"""Internally called to make an alias redirect URL."""
- url = self.build(endpoint, values, method, append_unknown=False,
- force_external=True)
+ url = self.build(
+ endpoint, values, method, append_unknown=False, force_external=True
+ )
if query_args:
- url += '?' + self.encode_query_args(query_args)
- assert url != path, 'detected invalid alias setting. No canonical ' \
- 'URL found'
+ url += "?" + self.encode_query_args(query_args)
+ assert url != path, "detected invalid alias setting. No canonical URL found"
return url
def _partial_build(self, endpoint, values, method, append_unknown):
@@ -1985,8 +2067,9 @@ class MapAdapter(object):
"""
# in case the method is none, try with the default method first
if method is None:
- rv = self._partial_build(endpoint, values, self.default_method,
- append_unknown)
+ rv = self._partial_build(
+ endpoint, values, self.default_method, append_unknown
+ )
if rv is not None:
return rv
@@ -1998,8 +2081,14 @@ class MapAdapter(object):
if rv is not None:
return rv
- def build(self, endpoint, values=None, method=None, force_external=False,
- append_unknown=True):
+ def build(
+ self,
+ endpoint,
+ values=None,
+ method=None,
+ force_external=False,
+ append_unknown=True,
+ ):
"""Building URLs works pretty much the other way round. Instead of
`match` you call `build` and pass it the endpoint and a dict of
arguments for the placeholders.
@@ -2100,10 +2189,13 @@ class MapAdapter(object):
(self.map.host_matching and host == self.server_name)
or (not self.map.host_matching and domain_part == self.subdomain)
):
- return '%s/%s' % (self.script_name.rstrip('/'), path.lstrip('/'))
- return str('%s//%s%s/%s' % (
- self.url_scheme + ':' if self.url_scheme else '',
- host,
- self.script_name[:-1],
- path.lstrip('/')
- ))
+ return "%s/%s" % (self.script_name.rstrip("/"), path.lstrip("/"))
+ return str(
+ "%s//%s%s/%s"
+ % (
+ self.url_scheme + ":" if self.url_scheme else "",
+ host,
+ self.script_name[:-1],
+ path.lstrip("/"),
+ )
+ )
diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py
index 6d2b8de0..1842afd0 100644
--- a/src/werkzeug/security.py
+++ b/src/werkzeug/security.py
@@ -8,31 +8,35 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import os
-import hmac
+import codecs
import hashlib
+import hmac
+import os
import posixpath
-import codecs
-from struct import Struct
from random import SystemRandom
+from struct import Struct
-from werkzeug._compat import range_type, PY2, text_type, izip, to_bytes, \
- to_native
-
+from ._compat import izip
+from ._compat import PY2
+from ._compat import range_type
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import to_native
-SALT_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
+SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
DEFAULT_PBKDF2_ITERATIONS = 150000
-
-_pack_int = Struct('>I').pack
-_builtin_safe_str_cmp = getattr(hmac, 'compare_digest', None)
+_pack_int = Struct(">I").pack
+_builtin_safe_str_cmp = getattr(hmac, "compare_digest", None)
_sys_rng = SystemRandom()
-_os_alt_seps = list(sep for sep in [os.path.sep, os.path.altsep]
- if sep not in (None, '/'))
+_os_alt_seps = list(
+ sep for sep in [os.path.sep, os.path.altsep] if sep not in (None, "/")
+)
-def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS,
- keylen=None, hashfunc=None):
+def pbkdf2_hex(
+ data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None
+):
"""Like :func:`pbkdf2_bin`, but returns a hex-encoded string.
.. versionadded:: 0.9
@@ -47,11 +51,12 @@ def pbkdf2_hex(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS,
from the hashlib module. Defaults to sha256.
"""
rv = pbkdf2_bin(data, salt, iterations, keylen, hashfunc)
- return to_native(codecs.encode(rv, 'hex_codec'))
+ return to_native(codecs.encode(rv, "hex_codec"))
-def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS,
- keylen=None, hashfunc=None):
+def pbkdf2_bin(
+ data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS, keylen=None, hashfunc=None
+):
"""Returns a binary digest for the PBKDF2 hash algorithm of `data`
with the given `salt`. It iterates `iterations` times and produces a
key of `keylen` bytes. By default, SHA-256 is used as hash function;
@@ -69,14 +74,14 @@ def pbkdf2_bin(data, salt, iterations=DEFAULT_PBKDF2_ITERATIONS,
from the hashlib module. Defaults to sha256.
"""
if not hashfunc:
- hashfunc = 'sha256'
+ hashfunc = "sha256"
data = to_bytes(data)
salt = to_bytes(salt)
if callable(hashfunc):
_test_hash = hashfunc()
- hash_name = getattr(_test_hash, 'name', None)
+ hash_name = getattr(_test_hash, "name", None)
else:
hash_name = hashfunc
return hashlib.pbkdf2_hmac(hash_name, data, salt, iterations, keylen)
@@ -91,9 +96,9 @@ def safe_str_cmp(a, b):
.. versionadded:: 0.7
"""
if isinstance(a, text_type):
- a = a.encode('utf-8')
+ a = a.encode("utf-8")
if isinstance(b, text_type):
- b = b.encode('utf-8')
+ b = b.encode("utf-8")
if _builtin_safe_str_cmp is not None:
return _builtin_safe_str_cmp(a, b)
@@ -115,8 +120,8 @@ def safe_str_cmp(a, b):
def gen_salt(length):
"""Generate a random string of SALT_CHARS with specified ``length``."""
if length <= 0:
- raise ValueError('Salt length must be positive')
- return ''.join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length))
+ raise ValueError("Salt length must be positive")
+ return "".join(_sys_rng.choice(SALT_CHARS) for _ in range_type(length))
def _hash_internal(method, salt, password):
@@ -124,31 +129,31 @@ def _hash_internal(method, salt, password):
unsalted and salted passwords. In case salted passwords are used
hmac is used.
"""
- if method == 'plain':
+ if method == "plain":
return password, method
if isinstance(password, text_type):
- password = password.encode('utf-8')
+ password = password.encode("utf-8")
- if method.startswith('pbkdf2:'):
- args = method[7:].split(':')
+ if method.startswith("pbkdf2:"):
+ args = method[7:].split(":")
if len(args) not in (1, 2):
- raise ValueError('Invalid number of arguments for PBKDF2')
+ raise ValueError("Invalid number of arguments for PBKDF2")
method = args.pop(0)
iterations = args and int(args[0] or 0) or DEFAULT_PBKDF2_ITERATIONS
is_pbkdf2 = True
- actual_method = 'pbkdf2:%s:%d' % (method, iterations)
+ actual_method = "pbkdf2:%s:%d" % (method, iterations)
else:
is_pbkdf2 = False
actual_method = method
if is_pbkdf2:
if not salt:
- raise ValueError('Salt is required for PBKDF2')
+ raise ValueError("Salt is required for PBKDF2")
rv = pbkdf2_hex(password, salt, iterations, hashfunc=method)
elif salt:
if isinstance(salt, text_type):
- salt = salt.encode('utf-8')
+ salt = salt.encode("utf-8")
mac = _create_mac(salt, password, method)
rv = mac.hexdigest()
else:
@@ -159,14 +164,17 @@ def _hash_internal(method, salt, password):
def _create_mac(key, msg, method):
if callable(method):
return hmac.HMAC(key, msg, method)
- hashfunc = lambda d=b'': hashlib.new(method, d)
+
+ def hashfunc(d=b""):
+ return hashlib.new(method, d)
+
# Python 2.7 used ``hasattr(digestmod, '__call__')``
# to detect if hashfunc is callable
hashfunc.__call__ = hashfunc
return hmac.HMAC(key, msg, hashfunc)
-def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8):
+def generate_password_hash(password, method="pbkdf2:sha256", salt_length=8):
"""Hash a password with the given method and salt with a string of
the given length. The format of the string returned includes the method
that was used so that :func:`check_password_hash` can check the hash.
@@ -191,9 +199,9 @@ def generate_password_hash(password, method='pbkdf2:sha256', salt_length=8):
to enable PBKDF2.
:param salt_length: the length of the salt in letters.
"""
- salt = gen_salt(salt_length) if method != 'plain' else ''
+ salt = gen_salt(salt_length) if method != "plain" else ""
h, actual_method = _hash_internal(method, salt, password)
- return '%s$%s$%s' % (actual_method, salt, h)
+ return "%s$%s$%s" % (actual_method, salt, h)
def check_password_hash(pwhash, password):
@@ -207,9 +215,9 @@ def check_password_hash(pwhash, password):
:func:`generate_password_hash`.
:param password: the plaintext password to compare against the hash.
"""
- if pwhash.count('$') < 2:
+ if pwhash.count("$") < 2:
return False
- method, salt, hashval = pwhash.split('$', 2)
+ method, salt, hashval = pwhash.split("$", 2)
return safe_str_cmp(_hash_internal(method, salt, password)[0], hashval)
@@ -222,14 +230,12 @@ def safe_join(directory, *pathnames):
"""
parts = [directory]
for filename in pathnames:
- if filename != '':
+ if filename != "":
filename = posixpath.normpath(filename)
for sep in _os_alt_seps:
if sep in filename:
return None
- if os.path.isabs(filename) or \
- filename == '..' or \
- filename.startswith('../'):
+ if os.path.isabs(filename) or filename == ".." or filename.startswith("../"):
return None
parts.append(filename)
return posixpath.join(*parts)
diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py
index 363bb029..4a179b3e 100644
--- a/src/werkzeug/serving.py
+++ b/src/werkzeug/serving.py
@@ -37,72 +37,74 @@
"""
import io
import os
+import signal
import socket
import sys
-import signal
-
-
-can_fork = hasattr(os, "fork")
+import werkzeug
+from ._compat import PY2
+from ._compat import reraise
+from ._compat import WIN
+from ._compat import wsgi_encoding_dance
+from ._internal import _log
+from .exceptions import InternalServerError
+from .urls import uri_to_iri
+from .urls import url_parse
+from .urls import url_unquote
try:
- import termcolor
+ import socketserver
+ from http.server import BaseHTTPRequestHandler
+ from http.server import HTTPServer
except ImportError:
- termcolor = None
+ import SocketServer as socketserver
+ from BaseHTTPServer import HTTPServer
+ from BaseHTTPServer import BaseHTTPRequestHandler
try:
import ssl
except ImportError:
+
class _SslDummy(object):
def __getattr__(self, name):
- raise RuntimeError('SSL support unavailable')
+ raise RuntimeError("SSL support unavailable")
+
ssl = _SslDummy()
+try:
+ import termcolor
+except ImportError:
+ termcolor = None
+
def _get_openssl_crypto_module():
try:
from OpenSSL import crypto
except ImportError:
- raise TypeError('Using ad-hoc certificates requires the pyOpenSSL '
- 'library.')
+ raise TypeError("Using ad-hoc certificates requires the pyOpenSSL library.")
else:
return crypto
-try:
- import SocketServer as socketserver
- from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
-except ImportError:
- import socketserver
- from http.server import HTTPServer, BaseHTTPRequestHandler
-
ThreadingMixIn = socketserver.ThreadingMixIn
+can_fork = hasattr(os, "fork")
if can_fork:
ForkingMixIn = socketserver.ForkingMixIn
else:
+
class ForkingMixIn(object):
pass
+
try:
af_unix = socket.AF_UNIX
except AttributeError:
af_unix = None
-import werkzeug
-from werkzeug._internal import _log
-from werkzeug._compat import PY2, WIN, reraise, wsgi_encoding_dance
-from werkzeug.urls import (
- url_parse,
- url_unquote,
- uri_to_iri,
-)
-from werkzeug.exceptions import InternalServerError
-
-
LISTEN_QUEUE = 128
-can_open_by_fd = not WIN and hasattr(socket, 'fromfd')
+can_open_by_fd = not WIN and hasattr(socket, "fromfd")
# On Python 3, ConnectionError represents the same errnos as
# socket.error from Python 2, while socket.error is an alias for the
@@ -126,12 +128,12 @@ class DechunkedInput(io.RawIOBase):
def read_chunk_len(self):
try:
- line = self._rfile.readline().decode('latin1')
+ line = self._rfile.readline().decode("latin1")
_len = int(line.strip(), 16)
except ValueError:
- raise IOError('Invalid chunk header')
+ raise IOError("Invalid chunk header")
if _len < 0:
- raise IOError('Negative chunk length not allowed')
+ raise IOError("Negative chunk length not allowed")
return _len
def readinto(self, buf):
@@ -152,7 +154,7 @@ class DechunkedInput(io.RawIOBase):
# buffer. If this operation fully consumes the chunk, this will
# reset self._len to 0.
n = min(len(buf), self._len)
- buf[read:read + n] = self._rfile.read(n)
+ buf[read : read + n] = self._rfile.read(n)
self._len -= n
read += n
@@ -160,8 +162,8 @@ class DechunkedInput(io.RawIOBase):
# Skip the terminating newline of a chunk that has been fully
# consumed. This also applies to the 0-sized final chunk
terminator = self._rfile.readline()
- if terminator not in (b'\n', b'\r\n', b'\r'):
- raise IOError('Missing chunk terminating newline')
+ if terminator not in (b"\n", b"\r\n", b"\r"):
+ raise IOError("Missing chunk terminating newline")
return read
@@ -172,7 +174,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
@property
def server_version(self):
- return 'Werkzeug/' + werkzeug.__version__
+ return "Werkzeug/" + werkzeug.__version__
def make_environ(self):
request_url = url_parse(self.path)
@@ -180,9 +182,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
def shutdown_server():
self.server.shutdown_signal = True
- url_scheme = 'http' if self.server.ssl_context is None else 'https'
+ url_scheme = "http" if self.server.ssl_context is None else "https"
if not self.client_address:
- self.client_address = '<local>'
+ self.client_address = "<local>"
if isinstance(self.client_address, str):
self.client_address = (self.client_address, 0)
else:
@@ -190,57 +192,57 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
path_info = url_unquote(request_url.path)
environ = {
- 'wsgi.version': (1, 0),
- 'wsgi.url_scheme': url_scheme,
- 'wsgi.input': self.rfile,
- 'wsgi.errors': sys.stderr,
- 'wsgi.multithread': self.server.multithread,
- 'wsgi.multiprocess': self.server.multiprocess,
- 'wsgi.run_once': False,
- 'werkzeug.server.shutdown': shutdown_server,
- 'SERVER_SOFTWARE': self.server_version,
- 'REQUEST_METHOD': self.command,
- 'SCRIPT_NAME': '',
- 'PATH_INFO': wsgi_encoding_dance(path_info),
- 'QUERY_STRING': wsgi_encoding_dance(request_url.query),
+ "wsgi.version": (1, 0),
+ "wsgi.url_scheme": url_scheme,
+ "wsgi.input": self.rfile,
+ "wsgi.errors": sys.stderr,
+ "wsgi.multithread": self.server.multithread,
+ "wsgi.multiprocess": self.server.multiprocess,
+ "wsgi.run_once": False,
+ "werkzeug.server.shutdown": shutdown_server,
+ "SERVER_SOFTWARE": self.server_version,
+ "REQUEST_METHOD": self.command,
+ "SCRIPT_NAME": "",
+ "PATH_INFO": wsgi_encoding_dance(path_info),
+ "QUERY_STRING": wsgi_encoding_dance(request_url.query),
# Non-standard, added by mod_wsgi, uWSGI
"REQUEST_URI": wsgi_encoding_dance(self.path),
# Non-standard, added by gunicorn
"RAW_URI": wsgi_encoding_dance(self.path),
- 'REMOTE_ADDR': self.address_string(),
- 'REMOTE_PORT': self.port_integer(),
- 'SERVER_NAME': self.server.server_address[0],
- 'SERVER_PORT': str(self.server.server_address[1]),
- 'SERVER_PROTOCOL': self.request_version
+ "REMOTE_ADDR": self.address_string(),
+ "REMOTE_PORT": self.port_integer(),
+ "SERVER_NAME": self.server.server_address[0],
+ "SERVER_PORT": str(self.server.server_address[1]),
+ "SERVER_PROTOCOL": self.request_version,
}
for key, value in self.get_header_items():
- key = key.upper().replace('-', '_')
- if key not in ('CONTENT_TYPE', 'CONTENT_LENGTH'):
- key = 'HTTP_' + key
+ key = key.upper().replace("-", "_")
+ if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
+ key = "HTTP_" + key
if key in environ:
value = "{},{}".format(environ[key], value)
environ[key] = value
- if environ.get('HTTP_TRANSFER_ENCODING', '').strip().lower() == 'chunked':
- environ['wsgi.input_terminated'] = True
- environ['wsgi.input'] = DechunkedInput(environ['wsgi.input'])
+ if environ.get("HTTP_TRANSFER_ENCODING", "").strip().lower() == "chunked":
+ environ["wsgi.input_terminated"] = True
+ environ["wsgi.input"] = DechunkedInput(environ["wsgi.input"])
if request_url.scheme and request_url.netloc:
- environ['HTTP_HOST'] = request_url.netloc
+ environ["HTTP_HOST"] = request_url.netloc
return environ
def run_wsgi(self):
- if self.headers.get('Expect', '').lower().strip() == '100-continue':
- self.wfile.write(b'HTTP/1.1 100 Continue\r\n\r\n')
+ if self.headers.get("Expect", "").lower().strip() == "100-continue":
+ self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.environ = environ = self.make_environ()
headers_set = []
headers_sent = []
def write(data):
- assert headers_set, 'write() before start_response'
+ assert headers_set, "write() before start_response"
if not headers_sent:
status, response_headers = headers_sent[:] = headers_set
try:
@@ -254,18 +256,21 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
self.send_header(key, value)
key = key.lower()
header_keys.add(key)
- if not ('content-length' in header_keys
- or environ['REQUEST_METHOD'] == 'HEAD'
- or code < 200 or code in (204, 304)):
+ if not (
+ "content-length" in header_keys
+ or environ["REQUEST_METHOD"] == "HEAD"
+ or code < 200
+ or code in (204, 304)
+ ):
self.close_connection = True
- self.send_header('Connection', 'close')
- if 'server' not in header_keys:
- self.send_header('Server', self.version_string())
- if 'date' not in header_keys:
- self.send_header('Date', self.date_time_string())
+ self.send_header("Connection", "close")
+ if "server" not in header_keys:
+ self.send_header("Server", self.version_string())
+ if "date" not in header_keys:
+ self.send_header("Date", self.date_time_string())
self.end_headers()
- assert isinstance(data, bytes), 'applications must write bytes'
+ assert isinstance(data, bytes), "applications must write bytes"
self.wfile.write(data)
self.wfile.flush()
@@ -277,7 +282,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
finally:
exc_info = None
elif headers_set:
- raise AssertionError('Headers already set')
+ raise AssertionError("Headers already set")
headers_set[:] = [status, response_headers]
return write
@@ -287,9 +292,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
for data in application_iter:
write(data)
if not headers_sent:
- write(b'')
+ write(b"")
finally:
- if hasattr(application_iter, 'close'):
+ if hasattr(application_iter, "close"):
application_iter.close()
application_iter = None
@@ -300,7 +305,8 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
except Exception:
if self.server.passthrough_errors:
raise
- from werkzeug.debug.tbtools import get_current_traceback
+ from .debug.tbtools import get_current_traceback
+
traceback = get_current_traceback(ignore_system_exceptions=True)
try:
# if we haven't yet sent the headers but they are set
@@ -310,8 +316,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
execute(InternalServerError())
except Exception:
pass
- self.server.log('error', 'Error on request:\n%s',
- traceback.plaintext)
+ self.server.log("error", "Error on request:\n%s", traceback.plaintext)
def handle(self):
"""Handles a request ignoring dropped connections."""
@@ -332,7 +337,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
later. It's the best we can do.
"""
# Windows does not provide SIGKILL, go with SIGTERM then.
- sig = getattr(signal, 'SIGKILL', signal.SIGTERM)
+ sig = getattr(signal, "SIGKILL", signal.SIGTERM)
# reloader active
if is_running_from_reloader():
os.kill(os.getpid(), sig)
@@ -358,19 +363,19 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
"""Send the response header and log the response code."""
self.log_request(code)
if message is None:
- message = code in self.responses and self.responses[code][0] or ''
- if self.request_version != 'HTTP/0.9':
+ message = code in self.responses and self.responses[code][0] or ""
+ if self.request_version != "HTTP/0.9":
hdr = "%s %d %s\r\n" % (self.protocol_version, code, message)
- self.wfile.write(hdr.encode('ascii'))
+ self.wfile.write(hdr.encode("ascii"))
def version_string(self):
return BaseHTTPRequestHandler.version_string(self).strip()
def address_string(self):
- if getattr(self, 'environ', None):
- return self.environ['REMOTE_ADDR']
+ if getattr(self, "environ", None):
+ return self.environ["REMOTE_ADDR"]
elif not self.client_address:
- return '<local>'
+ return "<local>"
elif isinstance(self.client_address, str):
return self.client_address
else:
@@ -379,7 +384,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
def port_integer(self):
return self.client_address[1]
- def log_request(self, code='-', size='-'):
+ def log_request(self, code="-", size="-"):
try:
path = uri_to_iri(self.path)
msg = "%s %s %s" % (self.command, path, self.request_version)
@@ -392,33 +397,35 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
if termcolor:
color = termcolor.colored
- if code[0] == '1': # 1xx - Informational
- msg = color(msg, attrs=['bold'])
- elif code[0] == '2': # 2xx - Success
- msg = color(msg, color='white')
- elif code == '304': # 304 - Resource Not Modified
- msg = color(msg, color='cyan')
- elif code[0] == '3': # 3xx - Redirection
- msg = color(msg, color='green')
- elif code == '404': # 404 - Resource Not Found
- msg = color(msg, color='yellow')
- elif code[0] == '4': # 4xx - Client Error
- msg = color(msg, color='red', attrs=['bold'])
- else: # 5xx, or any other response
- msg = color(msg, color='magenta', attrs=['bold'])
-
- self.log('info', '"%s" %s %s', msg, code, size)
+ if code[0] == "1": # 1xx - Informational
+ msg = color(msg, attrs=["bold"])
+ elif code[0] == "2": # 2xx - Success
+ msg = color(msg, color="white")
+ elif code == "304": # 304 - Resource Not Modified
+ msg = color(msg, color="cyan")
+ elif code[0] == "3": # 3xx - Redirection
+ msg = color(msg, color="green")
+ elif code == "404": # 404 - Resource Not Found
+ msg = color(msg, color="yellow")
+ elif code[0] == "4": # 4xx - Client Error
+ msg = color(msg, color="red", attrs=["bold"])
+ else: # 5xx, or any other response
+ msg = color(msg, color="magenta", attrs=["bold"])
+
+ self.log("info", '"%s" %s %s', msg, code, size)
def log_error(self, *args):
- self.log('error', *args)
+ self.log("error", *args)
def log_message(self, format, *args):
- self.log('info', format, *args)
+ self.log("info", format, *args)
def log(self, type, message, *args):
- _log(type, '%s - - [%s] %s\n' % (self.address_string(),
- self.log_date_time_string(),
- message % args))
+ _log(
+ type,
+ "%s - - [%s] %s\n"
+ % (self.address_string(), self.log_date_time_string(), message % args),
+ )
def get_header_items(self):
"""
@@ -434,15 +441,18 @@ class WSGIRequestHandler(BaseHTTPRequestHandler, object):
         :return: List of tuples containing header key/value pairs
"""
if PY2:
- # For Python 2, process the headers manually according to W3C RFC 2616 Section 4.2
+ # For Python 2, process the headers manually according to
+            # RFC 2616 Section 4.2.
items = []
for header in self.headers.headers:
- # Remove the \n\r from the header and split on the : to get the field name and value
+                # Remove "\r\n" from the header and split on ":" to get
+ # the field name and value.
key, value = header[0:-2].split(":", 1)
- # Add the key and the value once stripped of leading white space. The specification
- # allows for stripping trailing white space but the Python 3 code does not strip
- # trailing white space. Therefore, trailing space will be left as is to match the
- # Python 3 behavior
+ # Add the key and the value once stripped of leading
+ # white space. The specification allows for stripping
+ # trailing white space but the Python 3 code does not
+ # strip trailing white space. Therefore, trailing space
+ # will be left as is to match the Python 3 behavior.
items.append((key, value.lstrip()))
else:
items = self.headers.items()
@@ -456,11 +466,12 @@ BaseRequestHandler = WSGIRequestHandler
def generate_adhoc_ssl_pair(cn=None):
from random import random
+
crypto = _get_openssl_crypto_module()
# pretty damn sure that this is not actually accepted by anyone
if cn is None:
- cn = '*'
+ cn = "*"
cert = crypto.X509()
cert.set_serial_number(int(random() * sys.maxsize))
@@ -469,7 +480,7 @@ def generate_adhoc_ssl_pair(cn=None):
subject = cert.get_subject()
subject.CN = cn
- subject.O = 'Dummy Certificate' # noqa: E741
+ subject.O = "Dummy Certificate" # noqa: E741
issuer = cert.get_issuer()
issuer.CN = subject.CN
@@ -478,7 +489,7 @@ def generate_adhoc_ssl_pair(cn=None):
pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, 2048)
cert.set_pubkey(pkey)
- cert.sign(pkey, 'sha256')
+ cert.sign(pkey, "sha256")
return cert, pkey
@@ -502,16 +513,17 @@ def make_ssl_devcert(base_path, host=None, cn=None):
:param cn: the `CN` to use.
"""
from OpenSSL import crypto
+
if host is not None:
- cn = '*.%s/CN=%s' % (host, host)
+ cn = "*.%s/CN=%s" % (host, host)
cert, pkey = generate_adhoc_ssl_pair(cn=cn)
- cert_file = base_path + '.crt'
- pkey_file = base_path + '.key'
+ cert_file = base_path + ".crt"
+ pkey_file = base_path + ".key"
- with open(cert_file, 'wb') as f:
+ with open(cert_file, "wb") as f:
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
- with open(pkey_file, 'wb') as f:
+ with open(pkey_file, "wb") as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey))
return cert_file, pkey_file
@@ -557,8 +569,8 @@ def load_ssl_context(cert_file, pkey_file=None, protocol=None):
class _SSLContext(object):
- '''A dummy class with a small subset of Python3's ``ssl.SSLContext``, only
- intended to be used with and by Werkzeug.'''
+ """A dummy class with a small subset of Python3's ``ssl.SSLContext``, only
+ intended to be used with and by Werkzeug."""
def __init__(self, protocol):
self._protocol = protocol
@@ -572,9 +584,13 @@ class _SSLContext(object):
self._password = password
def wrap_socket(self, sock, **kwargs):
- return ssl.wrap_socket(sock, keyfile=self._keyfile,
- certfile=self._certfile,
- ssl_version=self._protocol, **kwargs)
+ return ssl.wrap_socket(
+ sock,
+ keyfile=self._keyfile,
+ certfile=self._certfile,
+ ssl_version=self._protocol,
+ **kwargs
+ )
def is_ssl_error(error=None):
@@ -582,6 +598,7 @@ def is_ssl_error(error=None):
exc_types = (ssl.SSLError,)
try:
from OpenSSL.SSL import Error
+
exc_types += (Error,)
except ImportError:
pass
@@ -606,9 +623,9 @@ def select_address_family(host, port):
# return info[0][0]
# except socket.gaierror:
# pass
- if host.startswith('unix://'):
+ if host.startswith("unix://"):
return socket.AF_UNIX
- elif ':' in host and hasattr(socket, 'AF_INET6'):
+ elif ":" in host and hasattr(socket, "AF_INET6"):
return socket.AF_INET6
return socket.AF_INET
@@ -617,10 +634,11 @@ def get_sockaddr(host, port, family):
"""Return a fully qualified socket address that can be passed to
:func:`socket.bind`."""
if family == af_unix:
- return host.split('://', 1)[1]
+ return host.split("://", 1)[1]
try:
res = socket.getaddrinfo(
- host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
+ host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP
+ )
except socket.gaierror:
return host, port
return res[0][4]
@@ -629,13 +647,20 @@ def get_sockaddr(host, port, family):
class BaseWSGIServer(HTTPServer, object):
"""Simple single-threaded, single-process WSGI server."""
+
multithread = False
multiprocess = False
request_queue_size = LISTEN_QUEUE
def __init__(
- self, host, port, app, handler=None, passthrough_errors=False,
- ssl_context=None, fd=None
+ self,
+ host,
+ port,
+ app,
+ handler=None,
+ passthrough_errors=False,
+ ssl_context=None,
+ fd=None,
):
if handler is None:
handler = WSGIRequestHandler
@@ -643,17 +668,13 @@ class BaseWSGIServer(HTTPServer, object):
self.address_family = select_address_family(host, port)
if fd is not None:
- real_sock = socket.fromfd(
- fd, self.address_family, socket.SOCK_STREAM)
+ real_sock = socket.fromfd(fd, self.address_family, socket.SOCK_STREAM)
port = 0
server_address = get_sockaddr(host, int(port), self.address_family)
# remove socket file if it already exists
- if (
- self.address_family == af_unix
- and os.path.exists(server_address)
- ):
+ if self.address_family == af_unix and os.path.exists(server_address):
os.unlink(server_address)
HTTPServer.__init__(self, server_address, handler)
@@ -672,7 +693,7 @@ class BaseWSGIServer(HTTPServer, object):
if ssl_context is not None:
if isinstance(ssl_context, tuple):
ssl_context = load_ssl_context(*ssl_context)
- if ssl_context == 'adhoc':
+ if ssl_context == "adhoc":
ssl_context = generate_adhoc_ssl_context()
# If we are on Python 2 the return value from socket.fromfd
# is an internal socket object but what we need for ssl wrap
@@ -714,6 +735,7 @@ class BaseWSGIServer(HTTPServer, object):
class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer):
"""A WSGI server that does threading."""
+
multithread = True
daemon_threads = True
@@ -721,35 +743,63 @@ class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer):
class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer):
"""A WSGI server that does forking."""
+
multiprocess = True
- def __init__(self, host, port, app, processes=40, handler=None,
- passthrough_errors=False, ssl_context=None, fd=None):
+ def __init__(
+ self,
+ host,
+ port,
+ app,
+ processes=40,
+ handler=None,
+ passthrough_errors=False,
+ ssl_context=None,
+ fd=None,
+ ):
if not can_fork:
- raise ValueError('Your platform does not support forking.')
- BaseWSGIServer.__init__(self, host, port, app, handler,
- passthrough_errors, ssl_context, fd)
+ raise ValueError("Your platform does not support forking.")
+ BaseWSGIServer.__init__(
+ self, host, port, app, handler, passthrough_errors, ssl_context, fd
+ )
self.max_children = processes
-def make_server(host=None, port=None, app=None, threaded=False, processes=1,
- request_handler=None, passthrough_errors=False,
- ssl_context=None, fd=None):
+def make_server(
+ host=None,
+ port=None,
+ app=None,
+ threaded=False,
+ processes=1,
+ request_handler=None,
+ passthrough_errors=False,
+ ssl_context=None,
+ fd=None,
+):
"""Create a new server instance that is either threaded, or forks
or just processes one request after another.
"""
if threaded and processes > 1:
- raise ValueError("cannot have a multithreaded and "
- "multi process server.")
+ raise ValueError("cannot have a multithreaded and multi process server.")
elif threaded:
- return ThreadedWSGIServer(host, port, app, request_handler,
- passthrough_errors, ssl_context, fd=fd)
+ return ThreadedWSGIServer(
+ host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd
+ )
elif processes > 1:
- return ForkingWSGIServer(host, port, app, processes, request_handler,
- passthrough_errors, ssl_context, fd=fd)
+ return ForkingWSGIServer(
+ host,
+ port,
+ app,
+ processes,
+ request_handler,
+ passthrough_errors,
+ ssl_context,
+ fd=fd,
+ )
else:
- return BaseWSGIServer(host, port, app, request_handler,
- passthrough_errors, ssl_context, fd=fd)
+ return BaseWSGIServer(
+ host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd
+ )
def is_running_from_reloader():
@@ -758,15 +808,26 @@ def is_running_from_reloader():
.. versionadded:: 0.10
"""
- return os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
-
-
-def run_simple(hostname, port, application, use_reloader=False,
- use_debugger=False, use_evalex=True,
- extra_files=None, reloader_interval=1,
- reloader_type='auto', threaded=False,
- processes=1, request_handler=None, static_files=None,
- passthrough_errors=False, ssl_context=None):
+ return os.environ.get("WERKZEUG_RUN_MAIN") == "true"
+
+
+def run_simple(
+ hostname,
+ port,
+ application,
+ use_reloader=False,
+ use_debugger=False,
+ use_evalex=True,
+ extra_files=None,
+ reloader_interval=1,
+ reloader_type="auto",
+ threaded=False,
+ processes=1,
+ request_handler=None,
+ static_files=None,
+ passthrough_errors=False,
+ ssl_context=None,
+):
"""Start a WSGI application. Optional features include a reloader,
multithreading and fork support.
@@ -837,36 +898,50 @@ def run_simple(hostname, port, application, use_reloader=False,
to disable SSL (which is the default).
"""
if not isinstance(port, int):
- raise TypeError('port must be an integer')
+ raise TypeError("port must be an integer")
if use_debugger:
- from werkzeug.debug import DebuggedApplication
+ from .debug import DebuggedApplication
+
application = DebuggedApplication(application, use_evalex)
if static_files:
- from werkzeug.wsgi import SharedDataMiddleware
+ from .middleware.shared_data import SharedDataMiddleware
+
application = SharedDataMiddleware(application, static_files)
def log_startup(sock):
- display_hostname = hostname if hostname not in ('', '*') else 'localhost'
- quit_msg = '(Press CTRL+C to quit)'
+ display_hostname = hostname if hostname not in ("", "*") else "localhost"
+ quit_msg = "(Press CTRL+C to quit)"
if sock.family == af_unix:
- _log('info', ' * Running on %s %s', display_hostname, quit_msg)
+ _log("info", " * Running on %s %s", display_hostname, quit_msg)
else:
- if ':' in display_hostname:
- display_hostname = '[%s]' % display_hostname
+ if ":" in display_hostname:
+ display_hostname = "[%s]" % display_hostname
port = sock.getsockname()[1]
- _log('info', ' * Running on %s://%s:%d/ %s',
- 'http' if ssl_context is None else 'https',
- display_hostname, port, quit_msg)
+ _log(
+ "info",
+ " * Running on %s://%s:%d/ %s",
+ "http" if ssl_context is None else "https",
+ display_hostname,
+ port,
+ quit_msg,
+ )
def inner():
try:
- fd = int(os.environ['WERKZEUG_SERVER_FD'])
+ fd = int(os.environ["WERKZEUG_SERVER_FD"])
except (LookupError, ValueError):
fd = None
- srv = make_server(hostname, port, application, threaded,
- processes, request_handler,
- passthrough_errors, ssl_context,
- fd=fd)
+ srv = make_server(
+ hostname,
+ port,
+ application,
+ threaded,
+ processes,
+ request_handler,
+ passthrough_errors,
+ ssl_context,
+ fd=fd,
+ )
if fd is None:
log_startup(srv.socket)
srv.serve_forever()
@@ -877,9 +952,11 @@ def run_simple(hostname, port, application, use_reloader=False,
# port is actually available.
if not is_running_from_reloader():
if port == 0 and not can_open_by_fd:
- raise ValueError('Cannot bind to a random port with enabled '
- 'reloader if the Python interpreter does '
- 'not support socket opening by fd.')
+ raise ValueError(
+ "Cannot bind to a random port with enabled "
+ "reloader if the Python interpreter does "
+ "not support socket opening by fd."
+ )
# Create and destroy a socket so that any exceptions are
# raised before we spawn a separate Python interpreter and
@@ -889,26 +966,26 @@ def run_simple(hostname, port, application, use_reloader=False,
s = socket.socket(address_family, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(server_address)
- if hasattr(s, 'set_inheritable'):
+ if hasattr(s, "set_inheritable"):
s.set_inheritable(True)
# If we can open the socket by file descriptor, then we can just
# reuse this one and our socket will survive the restarts.
if can_open_by_fd:
- os.environ['WERKZEUG_SERVER_FD'] = str(s.fileno())
+ os.environ["WERKZEUG_SERVER_FD"] = str(s.fileno())
s.listen(LISTEN_QUEUE)
log_startup(s)
else:
s.close()
if address_family == af_unix:
- _log('info', "Unlinking %s" % server_address)
+ _log("info", "Unlinking %s" % server_address)
os.unlink(server_address)
# Do not use relative imports, otherwise "python -m werkzeug.serving"
# breaks.
- from werkzeug._reloader import run_with_reloader
- run_with_reloader(inner, extra_files, reloader_interval,
- reloader_type)
+ from ._reloader import run_with_reloader
+
+ run_with_reloader(inner, extra_files, reloader_interval, reloader_type)
else:
inner()
@@ -916,46 +993,63 @@ def run_simple(hostname, port, application, use_reloader=False,
def run_with_reloader(*args, **kwargs):
# People keep using undocumented APIs. Do not use this function
# please, we do not guarantee that it continues working.
- from werkzeug._reloader import run_with_reloader
+ from ._reloader import run_with_reloader
+
return run_with_reloader(*args, **kwargs)
def main():
- '''A simple command-line interface for :py:func:`run_simple`.'''
+ """A simple command-line interface for :py:func:`run_simple`."""
# in contrast to argparse, this works at least under Python < 2.7
import optparse
- from werkzeug.utils import import_string
-
- parser = optparse.OptionParser(
- usage='Usage: %prog [options] app_module:app_object')
- parser.add_option('-b', '--bind', dest='address',
- help='The hostname:port the app should listen on.')
- parser.add_option('-d', '--debug', dest='use_debugger',
- action='store_true', default=False,
- help='Use Werkzeug\'s debugger.')
- parser.add_option('-r', '--reload', dest='use_reloader',
- action='store_true', default=False,
- help='Reload Python process if modules change.')
+ from .utils import import_string
+
+ parser = optparse.OptionParser(usage="Usage: %prog [options] app_module:app_object")
+ parser.add_option(
+ "-b",
+ "--bind",
+ dest="address",
+ help="The hostname:port the app should listen on.",
+ )
+ parser.add_option(
+ "-d",
+ "--debug",
+ dest="use_debugger",
+ action="store_true",
+ default=False,
+ help="Use Werkzeug's debugger.",
+ )
+ parser.add_option(
+ "-r",
+ "--reload",
+ dest="use_reloader",
+ action="store_true",
+ default=False,
+ help="Reload Python process if modules change.",
+ )
options, args = parser.parse_args()
hostname, port = None, None
if options.address:
- address = options.address.split(':')
+ address = options.address.split(":")
hostname = address[0]
if len(address) > 1:
port = address[1]
if len(args) != 1:
- sys.stdout.write('No application supplied, or too much. See --help\n')
+ sys.stdout.write("No application supplied, or too much. See --help\n")
sys.exit(1)
app = import_string(args[0])
run_simple(
- hostname=(hostname or '127.0.0.1'), port=int(port or 5000),
- application=app, use_reloader=options.use_reloader,
- use_debugger=options.use_debugger
+ hostname=(hostname or "127.0.0.1"),
+ port=int(port or 5000),
+ application=app,
+ use_reloader=options.use_reloader,
+ use_debugger=options.use_debugger,
)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
main()
diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py
index 8f04ec21..9aeb4127 100644
--- a/src/werkzeug/test.py
+++ b/src/werkzeug/test.py
@@ -8,56 +8,69 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
import mimetypes
-from time import time
-
-from random import random
+import sys
+from io import BytesIO
from itertools import chain
+from random import random
from tempfile import TemporaryFile
-from io import BytesIO
+from time import time
+
+from ._compat import iteritems
+from ._compat import iterlists
+from ._compat import itervalues
+from ._compat import make_literal_wrapper
+from ._compat import reraise
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import wsgi_encoding_dance
+from ._internal import _get_environ
+from .datastructures import CallbackDict
+from .datastructures import CombinedMultiDict
+from .datastructures import EnvironHeaders
+from .datastructures import FileMultiDict
+from .datastructures import FileStorage
+from .datastructures import Headers
+from .datastructures import MultiDict
+from .http import dump_cookie
+from .http import dump_options_header
+from .http import parse_options_header
+from .urls import iri_to_uri
+from .urls import url_encode
+from .urls import url_fix
+from .urls import url_parse
+from .urls import url_unparse
+from .urls import url_unquote
+from .utils import get_content_type
+from .wrappers import BaseRequest
+from .wsgi import ClosingIterator
+from .wsgi import get_current_url
try:
- from urllib2 import Request as U2Request
-except ImportError:
from urllib.request import Request as U2Request
+except ImportError:
+ from urllib2 import Request as U2Request
+
try:
from http.cookiejar import CookieJar
-except ImportError: # Py2
+except ImportError:
from cookielib import CookieJar
-from werkzeug._compat import iterlists, iteritems, itervalues, to_bytes, \
- string_types, text_type, reraise, wsgi_encoding_dance, \
- make_literal_wrapper
-from werkzeug._internal import _get_environ
-from werkzeug.wrappers import BaseRequest
-from werkzeug.urls import url_encode, url_fix, iri_to_uri, url_unquote, \
- url_unparse, url_parse
-from werkzeug.wsgi import get_current_url, ClosingIterator
-from werkzeug.utils import dump_cookie, get_content_type
-from werkzeug.datastructures import (
- FileMultiDict,
- MultiDict,
- CombinedMultiDict,
- Headers,
- FileStorage,
- CallbackDict,
- EnvironHeaders,
-)
-from werkzeug.http import dump_options_header, parse_options_header
-
-
-def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500,
- boundary=None, charset='utf-8'):
+
+def stream_encode_multipart(
+ values, use_tempfile=True, threshold=1024 * 500, boundary=None, charset="utf-8"
+):
"""Encode a dict of values (either strings or file descriptors or
:class:`FileStorage` objects.) into a multipart encoded string stored
in a file descriptor.
"""
if boundary is None:
- boundary = '---------------WerkzeugFormPart_%s%s' % (time(), random())
+ boundary = "---------------WerkzeugFormPart_%s%s" % (time(), random())
_closure = [BytesIO(), 0, False]
if use_tempfile:
+
def write_binary(string):
stream, total_length, on_disk = _closure
if on_disk:
@@ -67,12 +80,13 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500,
if length + _closure[1] <= threshold:
stream.write(string)
else:
- new_stream = TemporaryFile('wb+')
+ new_stream = TemporaryFile("wb+")
new_stream.write(stream.getvalue())
new_stream.write(string)
_closure[0] = new_stream
_closure[2] = True
_closure[1] = total_length + length
+
else:
write_binary = _closure[0].write
@@ -84,22 +98,22 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500,
for key, values in iterlists(values):
for value in values:
- write('--%s\r\nContent-Disposition: form-data; name="%s"' %
- (boundary, key))
- reader = getattr(value, 'read', None)
+ write('--%s\r\nContent-Disposition: form-data; name="%s"' % (boundary, key))
+ reader = getattr(value, "read", None)
if reader is not None:
- filename = getattr(value, 'filename',
- getattr(value, 'name', None))
- content_type = getattr(value, 'content_type', None)
+ filename = getattr(value, "filename", getattr(value, "name", None))
+ content_type = getattr(value, "content_type", None)
if content_type is None:
- content_type = filename and \
- mimetypes.guess_type(filename)[0] or \
- 'application/octet-stream'
+ content_type = (
+ filename
+ and mimetypes.guess_type(filename)[0]
+ or "application/octet-stream"
+ )
if filename is not None:
write('; filename="%s"\r\n' % filename)
else:
- write('\r\n')
- write('Content-Type: %s\r\n\r\n' % content_type)
+ write("\r\n")
+ write("Content-Type: %s\r\n\r\n" % content_type)
while 1:
chunk = reader(16384)
if not chunk:
@@ -110,22 +124,23 @@ def stream_encode_multipart(values, use_tempfile=True, threshold=1024 * 500,
value = str(value)
value = to_bytes(value, charset)
- write('\r\n\r\n')
+ write("\r\n\r\n")
write_binary(value)
- write('\r\n')
- write('--%s--\r\n' % boundary)
+ write("\r\n")
+ write("--%s--\r\n" % boundary)
length = int(_closure[0].tell())
_closure[0].seek(0)
return _closure[0], length, boundary
-def encode_multipart(values, boundary=None, charset='utf-8'):
+def encode_multipart(values, boundary=None, charset="utf-8"):
"""Like `stream_encode_multipart` but returns a tuple in the form
(``boundary``, ``data``) where data is a bytestring.
"""
stream, length, boundary = stream_encode_multipart(
- values, use_tempfile=False, boundary=boundary, charset=charset)
+ values, use_tempfile=False, boundary=boundary, charset=charset
+ )
return boundary, stream.read()
@@ -135,6 +150,7 @@ def File(fd, filename=None, mimetype=None):
.. deprecated:: 0.5
"""
from warnings import warn
+
warn(
"'werkzeug.test.File' is deprecated as of version 0.5 and will"
" be removed in version 1.0. Use 'EnvironBuilder' or"
@@ -194,17 +210,16 @@ class _TestCookieJar(CookieJar):
"""
cvals = []
for cookie in self:
- cvals.append('%s=%s' % (cookie.name, cookie.value))
+ cvals.append("%s=%s" % (cookie.name, cookie.value))
if cvals:
- environ['HTTP_COOKIE'] = '; '.join(cvals)
+ environ["HTTP_COOKIE"] = "; ".join(cvals)
def extract_wsgi(self, environ, headers):
"""Extract the server's set-cookie headers as cookies into the
cookie jar.
"""
self.extract_cookies(
- _TestCookieResponse(headers),
- U2Request(get_current_url(environ)),
+ _TestCookieResponse(headers), U2Request(get_current_url(environ))
)
@@ -307,7 +322,7 @@ class EnvironBuilder(object):
"""
#: the server protocol to use. defaults to HTTP/1.1
- server_protocol = 'HTTP/1.1'
+ server_protocol = "HTTP/1.1"
#: the wsgi version to use. defaults to (1, 0)
wsgi_version = (1, 0)
@@ -316,21 +331,37 @@ class EnvironBuilder(object):
request_class = BaseRequest
import json
+
#: The serialization function used when ``json`` is passed.
json_dumps = staticmethod(json.dumps)
del json
- def __init__(self, path='/', base_url=None, query_string=None,
- method='GET', input_stream=None, content_type=None,
- content_length=None, errors_stream=None, multithread=False,
- multiprocess=False, run_once=False, headers=None, data=None,
- environ_base=None, environ_overrides=None, charset='utf-8',
- mimetype=None, json=None):
+ def __init__(
+ self,
+ path="/",
+ base_url=None,
+ query_string=None,
+ method="GET",
+ input_stream=None,
+ content_type=None,
+ content_length=None,
+ errors_stream=None,
+ multithread=False,
+ multiprocess=False,
+ run_once=False,
+ headers=None,
+ data=None,
+ environ_base=None,
+ environ_overrides=None,
+ charset="utf-8",
+ mimetype=None,
+ json=None,
+ ):
path_s = make_literal_wrapper(path)
- if query_string is not None and path_s('?') in path:
- raise ValueError('Query string is defined in the path and as an argument')
- if query_string is None and path_s('?') in path:
- path, query_string = path.split(path_s('?'), 1)
+ if query_string is not None and path_s("?") in path:
+ raise ValueError("Query string is defined in the path and as an argument")
+ if query_string is None and path_s("?") in path:
+ path, query_string = path.split(path_s("?"), 1)
self.charset = charset
self.path = iri_to_uri(path)
if base_url is not None:
@@ -375,8 +406,8 @@ class EnvironBuilder(object):
if data:
if input_stream is not None:
- raise TypeError('can\'t provide input stream and data')
- if hasattr(data, 'read'):
+ raise TypeError("can't provide input stream and data")
+ if hasattr(data, "read"):
data = data.read()
if isinstance(data, text_type):
data = data.encode(self.charset)
@@ -386,8 +417,7 @@ class EnvironBuilder(object):
self.content_length = len(data)
else:
for key, value in _iter_data(data):
- if isinstance(value, (tuple, dict)) or \
- hasattr(value, 'read'):
+ if isinstance(value, (tuple, dict)) or hasattr(value, "read"):
self._add_file_from_data(key, value)
else:
self.form.setlistdefault(key).append(value)
@@ -406,9 +436,7 @@ class EnvironBuilder(object):
out = {
"path": environ["PATH_INFO"],
"base_url": cls._make_base_url(
- environ["wsgi.url_scheme"],
- headers.pop("Host"),
- environ["SCRIPT_NAME"],
+ environ["wsgi.url_scheme"], headers.pop("Host"), environ["SCRIPT_NAME"]
),
"query_string": environ["QUERY_STRING"],
"method": environ["REQUEST_METHOD"],
@@ -430,6 +458,7 @@ class EnvironBuilder(object):
self.files.add_file(key, *value)
elif isinstance(value, dict):
from warnings import warn
+
warn(
"Passing a dict as file data is deprecated as of"
" version 0.5 and will be removed in version 1.0. Use"
@@ -438,9 +467,9 @@ class EnvironBuilder(object):
stacklevel=2,
)
value = dict(value)
- mimetype = value.pop('mimetype', None)
+ mimetype = value.pop("mimetype", None)
if mimetype is not None:
- value['content_type'] = mimetype
+ value["content_type"] = mimetype
self.files.add_file(key, **value)
else:
self.files.add_file(key, value)
@@ -459,90 +488,100 @@ class EnvironBuilder(object):
@base_url.setter
def base_url(self, value):
if value is None:
- scheme = 'http'
- netloc = 'localhost'
- script_root = ''
+ scheme = "http"
+ netloc = "localhost"
+ script_root = ""
else:
scheme, netloc, script_root, qs, anchor = url_parse(value)
if qs or anchor:
- raise ValueError('base url must not contain a query string '
- 'or fragment')
- self.script_root = script_root.rstrip('/')
+ raise ValueError("base url must not contain a query string or fragment")
+ self.script_root = script_root.rstrip("/")
self.host = netloc
self.url_scheme = scheme
def _get_content_type(self):
- ct = self.headers.get('Content-Type')
+ ct = self.headers.get("Content-Type")
if ct is None and not self._input_stream:
if self._files:
- return 'multipart/form-data'
+ return "multipart/form-data"
elif self._form:
- return 'application/x-www-form-urlencoded'
+ return "application/x-www-form-urlencoded"
return None
return ct
def _set_content_type(self, value):
if value is None:
- self.headers.pop('Content-Type', None)
+ self.headers.pop("Content-Type", None)
else:
- self.headers['Content-Type'] = value
-
- content_type = property(_get_content_type, _set_content_type, doc='''
- The content type for the request. Reflected from and to the
- :attr:`headers`. Do not set if you set :attr:`files` or
- :attr:`form` for auto detection.''')
+ self.headers["Content-Type"] = value
+
+ content_type = property(
+ _get_content_type,
+ _set_content_type,
+ doc="""The content type for the request. Reflected from and to
+ the :attr:`headers`. Do not set if you set :attr:`files` or
+ :attr:`form` for auto detection.""",
+ )
del _get_content_type, _set_content_type
def _get_content_length(self):
- return self.headers.get('Content-Length', type=int)
+ return self.headers.get("Content-Length", type=int)
def _get_mimetype(self):
ct = self.content_type
if ct:
- return ct.split(';')[0].strip()
+ return ct.split(";")[0].strip()
def _set_mimetype(self, value):
self.content_type = get_content_type(value, self.charset)
def _get_mimetype_params(self):
def on_update(d):
- self.headers['Content-Type'] = \
- dump_options_header(self.mimetype, d)
- d = parse_options_header(self.headers.get('content-type', ''))[1]
+ self.headers["Content-Type"] = dump_options_header(self.mimetype, d)
+
+ d = parse_options_header(self.headers.get("content-type", ""))[1]
return CallbackDict(d, on_update)
- mimetype = property(_get_mimetype, _set_mimetype, doc='''
- The mimetype (content type without charset etc.)
+ mimetype = property(
+ _get_mimetype,
+ _set_mimetype,
+ doc="""The mimetype (content type without charset etc.)
.. versionadded:: 0.14
- ''')
- mimetype_params = property(_get_mimetype_params, doc='''
- The mimetype parameters as dict. For example if the content
- type is ``text/html; charset=utf-8`` the params would be
+ """,
+ )
+ mimetype_params = property(
+ _get_mimetype_params,
+ doc=""" The mimetype parameters as dict. For example if the
+ content type is ``text/html; charset=utf-8`` the params would be
``{'charset': 'utf-8'}``.
.. versionadded:: 0.14
- ''')
+ """,
+ )
del _get_mimetype, _set_mimetype, _get_mimetype_params
def _set_content_length(self, value):
if value is None:
- self.headers.pop('Content-Length', None)
+ self.headers.pop("Content-Length", None)
else:
- self.headers['Content-Length'] = str(value)
+ self.headers["Content-Length"] = str(value)
- content_length = property(_get_content_length, _set_content_length, doc='''
- The content length as integer. Reflected from and to the
+ content_length = property(
+ _get_content_length,
+ _set_content_length,
+ doc="""The content length as integer. Reflected from and to the
:attr:`headers`. Do not set if you set :attr:`files` or
- :attr:`form` for auto detection.''')
+ :attr:`form` for auto detection.""",
+ )
del _get_content_length, _set_content_length
- def form_property(name, storage, doc):
- key = '_' + name
+ def form_property(name, storage, doc): # noqa: B902
+ key = "_" + name
def getter(self):
if self._input_stream is not None:
- raise AttributeError('an input stream is defined')
+ raise AttributeError("an input stream is defined")
rv = getattr(self, key)
if rv is None:
rv = storage()
@@ -553,14 +592,17 @@ class EnvironBuilder(object):
def setter(self, value):
self._input_stream = None
setattr(self, key, value)
+
return property(getter, setter, doc=doc)
- form = form_property('form', MultiDict, doc='''
- A :class:`MultiDict` of form values.''')
- files = form_property('files', FileMultiDict, doc='''
- A :class:`FileMultiDict` of uploaded files. You can use the
- :meth:`~FileMultiDict.add_file` method to add new files to the
- dict.''')
+ form = form_property("form", MultiDict, doc="A :class:`MultiDict` of form values.")
+ files = form_property(
+ "files",
+ FileMultiDict,
+ doc="""A :class:`FileMultiDict` of uploaded files. You can use
+ the :meth:`~FileMultiDict.add_file` method to add new files to
+ the dict.""",
+ )
del form_property
def _get_input_stream(self):
@@ -570,30 +612,36 @@ class EnvironBuilder(object):
self._input_stream = value
self._form = self._files = None
- input_stream = property(_get_input_stream, _set_input_stream, doc='''
- An optional input stream. If you set this it will clear
- :attr:`form` and :attr:`files`.''')
+ input_stream = property(
+ _get_input_stream,
+ _set_input_stream,
+ doc="""An optional input stream. If you set this it will clear
+ :attr:`form` and :attr:`files`.""",
+ )
del _get_input_stream, _set_input_stream
def _get_query_string(self):
if self._query_string is None:
if self._args is not None:
return url_encode(self._args, charset=self.charset)
- return ''
+ return ""
return self._query_string
def _set_query_string(self, value):
self._query_string = value
self._args = None
- query_string = property(_get_query_string, _set_query_string, doc='''
- The query string. If you set this to a string :attr:`args` will
- no longer be available.''')
+ query_string = property(
+ _get_query_string,
+ _set_query_string,
+ doc="""The query string. If you set this to a string
+ :attr:`args` will no longer be available.""",
+ )
del _get_query_string, _set_query_string
def _get_args(self):
if self._query_string is not None:
- raise AttributeError('a query string is defined')
+ raise AttributeError("a query string is defined")
if self._args is None:
self._args = MultiDict()
return self._args
@@ -602,22 +650,23 @@ class EnvironBuilder(object):
self._query_string = None
self._args = value
- args = property(_get_args, _set_args, doc='''
- The URL arguments as :class:`MultiDict`.''')
+ args = property(
+ _get_args, _set_args, doc="The URL arguments as :class:`MultiDict`."
+ )
del _get_args, _set_args
@property
def server_name(self):
"""The server name (read-only, use :attr:`host` to set)"""
- return self.host.split(':', 1)[0]
+ return self.host.split(":", 1)[0]
@property
def server_port(self):
"""The server port as integer (read-only, use :attr:`host` to set)"""
- pieces = self.host.split(':', 1)
+ pieces = self.host.split(":", 1)
if len(pieces) == 2 and pieces[1].isdigit():
return int(pieces[1])
- elif self.url_scheme == 'https':
+ elif self.url_scheme == "https":
return 443
return 80
@@ -665,15 +714,16 @@ class EnvironBuilder(object):
end_pos = input_stream.tell()
input_stream.seek(start_pos)
content_length = end_pos - start_pos
- elif mimetype == 'multipart/form-data':
+ elif mimetype == "multipart/form-data":
values = CombinedMultiDict([self.form, self.files])
- input_stream, content_length, boundary = \
- stream_encode_multipart(values, charset=self.charset)
+ input_stream, content_length, boundary = stream_encode_multipart(
+ values, charset=self.charset
+ )
content_type = mimetype + '; boundary="%s"' % boundary
- elif mimetype == 'application/x-www-form-urlencoded':
+ elif mimetype == "application/x-www-form-urlencoded":
# XXX: py2v3 review
values = url_encode(self.form, charset=self.charset)
- values = values.encode('ascii')
+ values = values.encode("ascii")
content_length = len(values)
input_stream = BytesIO(values)
else:
@@ -688,40 +738,42 @@ class EnvironBuilder(object):
qs = wsgi_encoding_dance(self.query_string)
- result.update({
- 'REQUEST_METHOD': self.method,
- 'SCRIPT_NAME': _path_encode(self.script_root),
- 'PATH_INFO': _path_encode(self.path),
- 'QUERY_STRING': qs,
- # Non-standard, added by mod_wsgi, uWSGI
- "REQUEST_URI": wsgi_encoding_dance(self.path),
- # Non-standard, added by gunicorn
- "RAW_URI": wsgi_encoding_dance(self.path),
- 'SERVER_NAME': self.server_name,
- 'SERVER_PORT': str(self.server_port),
- 'HTTP_HOST': self.host,
- 'SERVER_PROTOCOL': self.server_protocol,
- 'wsgi.version': self.wsgi_version,
- 'wsgi.url_scheme': self.url_scheme,
- 'wsgi.input': input_stream,
- 'wsgi.errors': self.errors_stream,
- 'wsgi.multithread': self.multithread,
- 'wsgi.multiprocess': self.multiprocess,
- 'wsgi.run_once': self.run_once
- })
+ result.update(
+ {
+ "REQUEST_METHOD": self.method,
+ "SCRIPT_NAME": _path_encode(self.script_root),
+ "PATH_INFO": _path_encode(self.path),
+ "QUERY_STRING": qs,
+ # Non-standard, added by mod_wsgi, uWSGI
+ "REQUEST_URI": wsgi_encoding_dance(self.path),
+ # Non-standard, added by gunicorn
+ "RAW_URI": wsgi_encoding_dance(self.path),
+ "SERVER_NAME": self.server_name,
+ "SERVER_PORT": str(self.server_port),
+ "HTTP_HOST": self.host,
+ "SERVER_PROTOCOL": self.server_protocol,
+ "wsgi.version": self.wsgi_version,
+ "wsgi.url_scheme": self.url_scheme,
+ "wsgi.input": input_stream,
+ "wsgi.errors": self.errors_stream,
+ "wsgi.multithread": self.multithread,
+ "wsgi.multiprocess": self.multiprocess,
+ "wsgi.run_once": self.run_once,
+ }
+ )
headers = self.headers.copy()
if content_type is not None:
- result['CONTENT_TYPE'] = content_type
+ result["CONTENT_TYPE"] = content_type
headers.set("Content-Type", content_type)
if content_length is not None:
- result['CONTENT_LENGTH'] = str(content_length)
+ result["CONTENT_LENGTH"] = str(content_length)
headers.set("Content-Length", content_length)
for key, value in headers.to_wsgi_list():
- result['HTTP_%s' % key.upper().replace('-', '_')] = value
+ result["HTTP_%s" % key.upper().replace("-", "_")] = value
if self.environ_overrides:
result.update(self.environ_overrides)
@@ -740,15 +792,12 @@ class EnvironBuilder(object):
class ClientRedirectError(Exception):
-
- """
- If a redirect loop is detected when using follow_redirects=True with
+ """If a redirect loop is detected when using follow_redirects=True with
the :cls:`Client`, then this exception is raised.
"""
class Client(object):
-
"""This class allows to send requests to a wrapped application.
The response wrapper can be a class or factory function that takes
@@ -781,8 +830,13 @@ class Client(object):
The ``json`` parameter.
"""
- def __init__(self, application, response_wrapper=None, use_cookies=True,
- allow_subdomain_redirects=False):
+ def __init__(
+ self,
+ application,
+ response_wrapper=None,
+ use_cookies=True,
+ allow_subdomain_redirects=False,
+ ):
self.application = application
self.response_wrapper = response_wrapper
if use_cookies:
@@ -791,24 +845,36 @@ class Client(object):
self.cookie_jar = None
self.allow_subdomain_redirects = allow_subdomain_redirects
- def set_cookie(self, server_name, key, value='', max_age=None,
- expires=None, path='/', domain=None, secure=None,
- httponly=False, charset='utf-8'):
+ def set_cookie(
+ self,
+ server_name,
+ key,
+ value="",
+ max_age=None,
+ expires=None,
+ path="/",
+ domain=None,
+ secure=None,
+ httponly=False,
+ charset="utf-8",
+ ):
"""Sets a cookie in the client's cookie jar. The server name
is required and has to match the one that is also passed to
the open call.
"""
- assert self.cookie_jar is not None, 'cookies disabled'
- header = dump_cookie(key, value, max_age, expires, path, domain,
- secure, httponly, charset)
- environ = create_environ(path, base_url='http://' + server_name)
- headers = [('Set-Cookie', header)]
+ assert self.cookie_jar is not None, "cookies disabled"
+ header = dump_cookie(
+ key, value, max_age, expires, path, domain, secure, httponly, charset
+ )
+ environ = create_environ(path, base_url="http://" + server_name)
+ headers = [("Set-Cookie", header)]
self.cookie_jar.extract_wsgi(environ, headers)
- def delete_cookie(self, server_name, key, path='/', domain=None):
+ def delete_cookie(self, server_name, key, path="/", domain=None):
"""Deletes a cookie in the test client."""
- self.set_cookie(server_name, key, expires=0, max_age=0,
- path=path, domain=domain)
+ self.set_cookie(
+ server_name, key, expires=0, max_age=0, path=path, domain=domain
+ )
def run_wsgi_app(self, environ, buffered=False):
"""Runs the wrapped WSGI app with the given environment."""
@@ -826,7 +892,7 @@ class Client(object):
scheme, netloc, path, qs, anchor = url_parse(new_location)
builder = EnvironBuilder.from_environ(environ, query_string=qs)
- to_name_parts = netloc.split(':', 1)[0].split(".")
+ to_name_parts = netloc.split(":", 1)[0].split(".")
from_name_parts = builder.server_name.split(".")
if to_name_parts != [""]:
@@ -840,7 +906,7 @@ class Client(object):
# Explain why a redirect to a different server name won't be followed.
if to_name_parts != from_name_parts:
- if to_name_parts[-len(from_name_parts):] == from_name_parts:
+ if to_name_parts[-len(from_name_parts) :] == from_name_parts:
if not self.allow_subdomain_redirects:
raise RuntimeError("Following subdomain redirects is not enabled.")
else:
@@ -849,9 +915,9 @@ class Client(object):
path_parts = path.split("/")
root_parts = builder.script_root.split("/")
- if path_parts[:len(root_parts)] == root_parts:
+ if path_parts[: len(root_parts)] == root_parts:
# Strip the script root from the path.
- builder.path = path[len(builder.script_root):]
+ builder.path = path[len(builder.script_root) :]
else:
# The new location is not under the script root, so use the
# whole path and clear the previous root.
@@ -907,9 +973,9 @@ class Client(object):
:param follow_redirects: Set this to True if the `Client` should
follow HTTP redirects.
"""
- as_tuple = kwargs.pop('as_tuple', False)
- buffered = kwargs.pop('buffered', False)
- follow_redirects = kwargs.pop('follow_redirects', False)
+ as_tuple = kwargs.pop("as_tuple", False)
+ buffered = kwargs.pop("buffered", False)
+ follow_redirects = kwargs.pop("follow_redirects", False)
environ = None
if not kwargs and len(args) == 1:
if isinstance(args[0], EnvironBuilder):
@@ -941,15 +1007,13 @@ class Client(object):
for _ in response[0]:
pass
- new_location = response[2]['location']
+ new_location = response[2]["location"]
new_redirect_entry = (new_location, status_code)
if new_redirect_entry in redirect_chain:
- raise ClientRedirectError('loop detected')
+ raise ClientRedirectError("loop detected")
redirect_chain.append(new_redirect_entry)
environ, response = self.resolve_redirect(
- response,
- new_location,
- environ, buffered=buffered
+ response, new_location, environ, buffered=buffered
)
if self.response_wrapper is not None:
@@ -960,49 +1024,46 @@ class Client(object):
def get(self, *args, **kw):
"""Like open but method is enforced to GET."""
- kw['method'] = 'GET'
+ kw["method"] = "GET"
return self.open(*args, **kw)
def patch(self, *args, **kw):
"""Like open but method is enforced to PATCH."""
- kw['method'] = 'PATCH'
+ kw["method"] = "PATCH"
return self.open(*args, **kw)
def post(self, *args, **kw):
"""Like open but method is enforced to POST."""
- kw['method'] = 'POST'
+ kw["method"] = "POST"
return self.open(*args, **kw)
def head(self, *args, **kw):
"""Like open but method is enforced to HEAD."""
- kw['method'] = 'HEAD'
+ kw["method"] = "HEAD"
return self.open(*args, **kw)
def put(self, *args, **kw):
"""Like open but method is enforced to PUT."""
- kw['method'] = 'PUT'
+ kw["method"] = "PUT"
return self.open(*args, **kw)
def delete(self, *args, **kw):
"""Like open but method is enforced to DELETE."""
- kw['method'] = 'DELETE'
+ kw["method"] = "DELETE"
return self.open(*args, **kw)
def options(self, *args, **kw):
"""Like open but method is enforced to OPTIONS."""
- kw['method'] = 'OPTIONS'
+ kw["method"] = "OPTIONS"
return self.open(*args, **kw)
def trace(self, *args, **kw):
"""Like open but method is enforced to TRACE."""
- kw['method'] = 'TRACE'
+ kw["method"] = "TRACE"
return self.open(*args, **kw)
def __repr__(self):
- return '<%s %r>' % (
- self.__class__.__name__,
- self.application
- )
+ return "<%s %r>" % (self.__class__.__name__, self.application)
def create_environ(*args, **kwargs):
@@ -1055,7 +1116,7 @@ def run_wsgi_app(app, environ, buffered=False):
return buffer.append
app_rv = app(environ, start_response)
- close_func = getattr(app_rv, 'close', None)
+ close_func = getattr(app_rv, "close", None)
app_iter = iter(app_rv)
# when buffering we emit the close call early and convert the
diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py
index 972c6ca1..8ea23bee 100644
--- a/src/werkzeug/testapp.py
+++ b/src/werkzeug/testapp.py
@@ -9,15 +9,19 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
+import base64
import os
import sys
-import werkzeug
from textwrap import wrap
-from werkzeug.wrappers import BaseRequest as Request, BaseResponse as Response
-from werkzeug.utils import escape
-import base64
-logo = Response(base64.b64decode('''
+import werkzeug
+from .utils import escape
+from .wrappers import BaseRequest as Request
+from .wrappers import BaseResponse as Response
+
+logo = Response(
+ base64.b64decode(
+ """
R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP/////////
//////////////////////yH5BAEKAAgALAAAAACgAKAAAAT+EMlJq704680R+F0ojmRpnuj0rWnrv
nB8rbRs33gu0bzu/0AObxgsGn3D5HHJbCUFyqZ0ukkSDlAidctNFg7gbI9LZlrBaHGtzAae0eloe25
@@ -51,10 +55,13 @@ UpyGlhjBUljyjHhWpf8OFaXwhp9O4T1gU9UeyPPa8A2l0p1kNqPXEVRm1AOs1oAGZU596t6SOR2mcB
Oco1srWtkaVrMUzIErrKri85keKqRQYX9VX0/eAUK1hrSu6HMEX3Qh2sCh0q0D2CtnUqS4hj62sE/z
aDs2Sg7MBS6xnQeooc2R2tC9YrKpEi9pLXfYXp20tDCpSP8rKlrD4axprb9u1Df5hSbz9QU0cRpfgn
kiIzwKucd0wsEHlLpe5yHXuc6FrNelOl7pY2+11kTWx7VpRu97dXA3DO1vbkhcb4zyvERYajQgAADs
-='''), mimetype='image/png')
+="""
+ ),
+ mimetype="image/png",
+)
-TEMPLATE = u'''\
+TEMPLATE = u"""\
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
"http://www.w3.org/TR/html4/loose.dtd">
<title>WSGI Information</title>
@@ -130,24 +137,27 @@ TEMPLATE = u'''\
environments.
<ul class="path">%(sys_path)s</ul>
</div>
-'''
+"""
def iter_sys_path():
- if os.name == 'posix':
+ if os.name == "posix":
+
def strip(x):
- prefix = os.path.expanduser('~')
+ prefix = os.path.expanduser("~")
if x.startswith(prefix):
- x = '~' + x[len(prefix):]
+ x = "~" + x[len(prefix) :]
return x
+
else:
- strip = lambda x: x
+
+ def strip(x):
+ return x
cwd = os.path.abspath(os.getcwd())
for item in sys.path:
path = os.path.join(cwd, item or os.path.curdir)
- yield strip(os.path.normpath(path)), \
- not os.path.isdir(path), path != item
+ yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item
def render_testapp(req):
@@ -156,51 +166,51 @@ def render_testapp(req):
except ImportError:
eggs = ()
else:
- eggs = sorted(pkg_resources.working_set,
- key=lambda x: x.project_name.lower())
+ eggs = sorted(pkg_resources.working_set, key=lambda x: x.project_name.lower())
python_eggs = []
for egg in eggs:
try:
version = egg.version
except (ValueError, AttributeError):
- version = 'unknown'
- python_eggs.append('<li>%s <small>[%s]</small>' % (
- escape(egg.project_name),
- escape(version)
- ))
+ version = "unknown"
+ python_eggs.append(
+ "<li>%s <small>[%s]</small>" % (escape(egg.project_name), escape(version))
+ )
wsgi_env = []
- sorted_environ = sorted(req.environ.items(),
- key=lambda x: repr(x[0]).lower())
+ sorted_environ = sorted(req.environ.items(), key=lambda x: repr(x[0]).lower())
for key, value in sorted_environ:
- wsgi_env.append('<tr><th>%s<td><code>%s</code>' % (
- escape(str(key)),
- ' '.join(wrap(escape(repr(value))))
- ))
+ wsgi_env.append(
+ "<tr><th>%s<td><code>%s</code>"
+ % (escape(str(key)), " ".join(wrap(escape(repr(value)))))
+ )
sys_path = []
for item, virtual, expanded in iter_sys_path():
class_ = []
if virtual:
- class_.append('virtual')
+ class_.append("virtual")
if expanded:
- class_.append('exp')
- sys_path.append('<li%s>%s' % (
- ' class="%s"' % ' '.join(class_) if class_ else '',
- escape(item)
- ))
-
- return (TEMPLATE % {
- 'python_version': '<br>'.join(escape(sys.version).splitlines()),
- 'platform': escape(sys.platform),
- 'os': escape(os.name),
- 'api_version': sys.api_version,
- 'byteorder': sys.byteorder,
- 'werkzeug_version': werkzeug.__version__,
- 'python_eggs': '\n'.join(python_eggs),
- 'wsgi_env': '\n'.join(wsgi_env),
- 'sys_path': '\n'.join(sys_path)
- }).encode('utf-8')
+ class_.append("exp")
+ sys_path.append(
+ "<li%s>%s"
+ % (' class="%s"' % " ".join(class_) if class_ else "", escape(item))
+ )
+
+ return (
+ TEMPLATE
+ % {
+ "python_version": "<br>".join(escape(sys.version).splitlines()),
+ "platform": escape(sys.platform),
+ "os": escape(os.name),
+ "api_version": sys.api_version,
+ "byteorder": sys.byteorder,
+ "werkzeug_version": werkzeug.__version__,
+ "python_eggs": "\n".join(python_eggs),
+ "wsgi_env": "\n".join(wsgi_env),
+ "sys_path": "\n".join(sys_path),
+ }
+ ).encode("utf-8")
def test_app(environ, start_response):
@@ -218,13 +228,14 @@ def test_app(environ, start_response):
the Python interpreter and the installed libraries.
"""
req = Request(environ, populate_request=False)
- if req.args.get('resource') == 'logo':
+ if req.args.get("resource") == "logo":
response = logo
else:
- response = Response(render_testapp(req), mimetype='text/html')
+ response = Response(render_testapp(req), mimetype="text/html")
return response(environ, start_response)
-if __name__ == '__main__':
- from werkzeug.serving import run_simple
- run_simple('localhost', 5000, test_app, use_reloader=True)
+if __name__ == "__main__":
+ from .serving import run_simple
+
+ run_simple("localhost", 5000, test_app, use_reloader=True)
diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py
index 09ec5b0c..38e9e5ad 100644
--- a/src/werkzeug/urls.py
+++ b/src/werkzeug/urls.py
@@ -15,56 +15,53 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-from collections import namedtuple
-
import codecs
import os
import re
+from collections import namedtuple
-from werkzeug._compat import (
- PY2,
- fix_tuple_repr,
- implements_to_string,
- make_literal_wrapper,
- normalize_string_tuple,
- text_type,
- to_native,
- to_unicode,
- try_coerce_native,
-)
-from werkzeug._internal import _decode_idna, _encode_idna
-from werkzeug.datastructures import MultiDict, iter_multi_items
+from ._compat import fix_tuple_repr
+from ._compat import implements_to_string
+from ._compat import make_literal_wrapper
+from ._compat import normalize_string_tuple
+from ._compat import PY2
+from ._compat import text_type
+from ._compat import to_native
+from ._compat import to_unicode
+from ._compat import try_coerce_native
+from ._internal import _decode_idna
+from ._internal import _encode_idna
+from .datastructures import iter_multi_items
+from .datastructures import MultiDict
# A regular expression for what a valid schema looks like
-_scheme_re = re.compile(r'^[a-zA-Z0-9+-.]+$')
+_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$")
# Characters that are safe in any part of an URL.
-_always_safe = frozenset(bytearray(
- b"abcdefghijklmnopqrstuvwxyz"
- b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- b"0123456789"
- b"-._~"
-))
-
-_hexdigits = '0123456789ABCDEFabcdef'
+_always_safe = frozenset(
+ bytearray(
+ b"abcdefghijklmnopqrstuvwxyz"
+ b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ b"0123456789"
+ b"-._~"
+ )
+)
+
+_hexdigits = "0123456789ABCDEFabcdef"
_hextobyte = dict(
- ((a + b).encode(), int(a + b, 16))
- for a in _hexdigits for b in _hexdigits
+ ((a + b).encode(), int(a + b, 16)) for a in _hexdigits for b in _hexdigits
)
-_bytetohex = [
- ('%%%02X' % char).encode('ascii') for char in range(256)
-]
+_bytetohex = [("%%%02X" % char).encode("ascii") for char in range(256)]
-_URLTuple = fix_tuple_repr(namedtuple(
- '_URLTuple',
- ['scheme', 'netloc', 'path', 'query', 'fragment']
-))
+_URLTuple = fix_tuple_repr(
+ namedtuple("_URLTuple", ["scheme", "netloc", "path", "query", "fragment"])
+)
class BaseURL(_URLTuple):
+ """Superclass of :py:class:`URL` and :py:class:`BytesURL`."""
- '''Superclass of :py:class:`URL` and :py:class:`BytesURL`.'''
__slots__ = ()
def replace(self, **kwargs):
@@ -92,8 +89,8 @@ class BaseURL(_URLTuple):
try:
rv = _encode_idna(rv)
except UnicodeError:
- rv = rv.encode('ascii', 'ignore')
- return to_native(rv, 'ascii', 'ignore')
+ rv = rv.encode("ascii", "ignore")
+ return to_native(rv, "ascii", "ignore")
@property
def port(self):
@@ -169,19 +166,24 @@ class BaseURL(_URLTuple):
def decode_netloc(self):
"""Decodes the netloc part into a string."""
- rv = _decode_idna(self.host or '')
+ rv = _decode_idna(self.host or "")
- if ':' in rv:
- rv = '[%s]' % rv
+ if ":" in rv:
+ rv = "[%s]" % rv
port = self.port
if port is not None:
- rv = '%s:%d' % (rv, port)
- auth = ':'.join(filter(None, [
- _url_unquote_legacy(self.raw_username or '', '/:%@'),
- _url_unquote_legacy(self.raw_password or '', '/:%@'),
- ]))
+ rv = "%s:%d" % (rv, port)
+ auth = ":".join(
+ filter(
+ None,
+ [
+ _url_unquote_legacy(self.raw_username or "", "/:%@"),
+ _url_unquote_legacy(self.raw_password or "", "/:%@"),
+ ],
+ )
+ )
if auth:
- rv = '%s@%s' % (auth, rv)
+ rv = "%s@%s" % (auth, rv)
return rv
def to_uri_tuple(self):
@@ -192,7 +194,7 @@ class BaseURL(_URLTuple):
It's usually more interesting to directly call :meth:`iri_to_uri` which
will return a string.
"""
- return url_parse(iri_to_uri(self).encode('ascii'))
+ return url_parse(iri_to_uri(self).encode("ascii"))
def to_iri_tuple(self):
"""Returns a :class:`URL` tuple that holds a IRI. This will try
@@ -223,42 +225,44 @@ class BaseURL(_URLTuple):
supported. Defaults to ``None`` which is
autodetect.
"""
- if self.scheme != 'file':
+ if self.scheme != "file":
return None, None
path = url_unquote(self.path)
host = self.netloc or None
if pathformat is None:
- if os.name == 'nt':
- pathformat = 'windows'
+ if os.name == "nt":
+ pathformat = "windows"
else:
- pathformat = 'posix'
+ pathformat = "posix"
- if pathformat == 'windows':
- if path[:1] == '/' and path[1:2].isalpha() and path[2:3] in '|:':
- path = path[1:2] + ':' + path[3:]
- windows_share = path[:3] in ('\\' * 3, '/' * 3)
+ if pathformat == "windows":
+ if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:":
+ path = path[1:2] + ":" + path[3:]
+ windows_share = path[:3] in ("\\" * 3, "/" * 3)
import ntpath
+
path = ntpath.normpath(path)
# Windows shared drives are represented as ``\\host\\directory``.
# That results in a URL like ``file://///host/directory``, and a
# path like ``///host/directory``. We need to special-case this
# because the path contains the hostname.
if windows_share and host is None:
- parts = path.lstrip('\\').split('\\', 1)
+ parts = path.lstrip("\\").split("\\", 1)
if len(parts) == 2:
host, path = parts
else:
host = parts[0]
- path = ''
- elif pathformat == 'posix':
+ path = ""
+ elif pathformat == "posix":
import posixpath
+
path = posixpath.normpath(path)
else:
- raise TypeError('Invalid path format %s' % repr(pathformat))
+ raise TypeError("Invalid path format %s" % repr(pathformat))
- if host in ('127.0.0.1', '::1', 'localhost'):
+ if host in ("127.0.0.1", "::1", "localhost"):
host = None
return host, path
@@ -291,7 +295,7 @@ class BaseURL(_URLTuple):
return rv, None
host = rv[1:idx]
- rest = rv[idx + 1:]
+ rest = rv[idx + 1 :]
if rest.startswith(self._colon):
return host, rest[1:]
return host, None
@@ -299,81 +303,84 @@ class BaseURL(_URLTuple):
@implements_to_string
class URL(BaseURL):
-
"""Represents a parsed URL. This behaves like a regular tuple but
also has some extra attributes that give further insight into the
URL.
"""
+
__slots__ = ()
- _at = '@'
- _colon = ':'
- _lbracket = '['
- _rbracket = ']'
+ _at = "@"
+ _colon = ":"
+ _lbracket = "["
+ _rbracket = "]"
def __str__(self):
return self.to_url()
def encode_netloc(self):
"""Encodes the netloc part to an ASCII safe URL as bytes."""
- rv = self.ascii_host or ''
- if ':' in rv:
- rv = '[%s]' % rv
+ rv = self.ascii_host or ""
+ if ":" in rv:
+ rv = "[%s]" % rv
port = self.port
if port is not None:
- rv = '%s:%d' % (rv, port)
- auth = ':'.join(filter(None, [
- url_quote(self.raw_username or '', 'utf-8', 'strict', '/:%'),
- url_quote(self.raw_password or '', 'utf-8', 'strict', '/:%'),
- ]))
+ rv = "%s:%d" % (rv, port)
+ auth = ":".join(
+ filter(
+ None,
+ [
+ url_quote(self.raw_username or "", "utf-8", "strict", "/:%"),
+ url_quote(self.raw_password or "", "utf-8", "strict", "/:%"),
+ ],
+ )
+ )
if auth:
- rv = '%s@%s' % (auth, rv)
+ rv = "%s@%s" % (auth, rv)
return to_native(rv)
- def encode(self, charset='utf-8', errors='replace'):
+ def encode(self, charset="utf-8", errors="replace"):
"""Encodes the URL to a tuple made out of bytes. The charset is
only being used for the path, query and fragment.
"""
return BytesURL(
- self.scheme.encode('ascii'),
+ self.scheme.encode("ascii"),
self.encode_netloc(),
self.path.encode(charset, errors),
self.query.encode(charset, errors),
- self.fragment.encode(charset, errors)
+ self.fragment.encode(charset, errors),
)
class BytesURL(BaseURL):
-
"""Represents a parsed URL in bytes."""
+
__slots__ = ()
- _at = b'@'
- _colon = b':'
- _lbracket = b'['
- _rbracket = b']'
+ _at = b"@"
+ _colon = b":"
+ _lbracket = b"["
+ _rbracket = b"]"
def __str__(self):
- return self.to_url().decode('utf-8', 'replace')
+ return self.to_url().decode("utf-8", "replace")
def encode_netloc(self):
"""Returns the netloc unchanged as bytes."""
return self.netloc
- def decode(self, charset='utf-8', errors='replace'):
+ def decode(self, charset="utf-8", errors="replace"):
"""Decodes the URL to a tuple made out of strings. The charset is
only being used for the path, query and fragment.
"""
return URL(
- self.scheme.decode('ascii'),
+ self.scheme.decode("ascii"),
self.decode_netloc(),
self.path.decode(charset, errors),
self.query.decode(charset, errors),
- self.fragment.decode(charset, errors)
+ self.fragment.decode(charset, errors),
)
-_unquote_maps = {
- frozenset(): _hextobyte,
-}
+_unquote_maps = {frozenset(): _hextobyte}
def _unquote_to_bytes(string, unsafe=""):
@@ -418,15 +425,14 @@ def _url_encode_impl(obj, charset, encode_keys, sort, key):
key = text_type(key).encode(charset)
if not isinstance(value, bytes):
value = text_type(value).encode(charset)
- yield _fast_url_quote_plus(key) + '=' + _fast_url_quote_plus(value)
+ yield _fast_url_quote_plus(key) + "=" + _fast_url_quote_plus(value)
-def _url_unquote_legacy(value, unsafe=''):
+def _url_unquote_legacy(value, unsafe=""):
try:
- return url_unquote(value, charset='utf-8',
- errors='strict', unsafe=unsafe)
+ return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe)
except UnicodeError:
- return url_unquote(value, charset='latin1', unsafe=unsafe)
+ return url_unquote(value, charset="latin1", unsafe=unsafe)
def url_parse(url, scheme=None, allow_fragments=True):
@@ -446,38 +452,39 @@ def url_parse(url, scheme=None, allow_fragments=True):
is_text_based = isinstance(url, text_type)
if scheme is None:
- scheme = s('')
- netloc = query = fragment = s('')
- i = url.find(s(':'))
- if i > 0 and _scheme_re.match(to_native(url[:i], errors='replace')):
+ scheme = s("")
+ netloc = query = fragment = s("")
+ i = url.find(s(":"))
+ if i > 0 and _scheme_re.match(to_native(url[:i], errors="replace")):
# make sure "iri" is not actually a port number (in which case
# "scheme" is really part of the path)
- rest = url[i + 1:]
- if not rest or any(c not in s('0123456789') for c in rest):
+ rest = url[i + 1 :]
+ if not rest or any(c not in s("0123456789") for c in rest):
# not a port number
scheme, url = url[:i].lower(), rest
- if url[:2] == s('//'):
+ if url[:2] == s("//"):
delim = len(url)
- for c in s('/?#'):
+ for c in s("/?#"):
wdelim = url.find(c, 2)
if wdelim >= 0:
delim = min(delim, wdelim)
netloc, url = url[2:delim], url[delim:]
- if (s('[') in netloc and s(']') not in netloc) or \
- (s(']') in netloc and s('[') not in netloc):
- raise ValueError('Invalid IPv6 URL')
+ if (s("[") in netloc and s("]") not in netloc) or (
+ s("]") in netloc and s("[") not in netloc
+ ):
+ raise ValueError("Invalid IPv6 URL")
- if allow_fragments and s('#') in url:
- url, fragment = url.split(s('#'), 1)
- if s('?') in url:
- url, query = url.split(s('?'), 1)
+ if allow_fragments and s("#") in url:
+ url, fragment = url.split(s("#"), 1)
+ if s("?") in url:
+ url, query = url.split(s("?"), 1)
result_type = URL if is_text_based else BytesURL
return result_type(scheme, netloc, url, query, fragment)
-def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe=''):
+def _make_fast_url_quote(charset="utf-8", errors="strict", safe="/:", unsafe=""):
"""Precompile the translation table for a URL encoding function.
Unlike :func:`url_quote`, the generated function only takes the
@@ -495,12 +502,15 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='')
unsafe = unsafe.encode(charset, errors)
safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe))
- table = [chr(c) if c in safe else '%%%02X' % c for c in range(256)]
+ table = [chr(c) if c in safe else "%%%02X" % c for c in range(256)]
if not PY2:
+
def quote(string):
return "".join([table[c] for c in string])
+
else:
+
def quote(string):
return "".join([table[c] for c in bytearray(string)])
@@ -508,14 +518,14 @@ def _make_fast_url_quote(charset='utf-8', errors='strict', safe='/:', unsafe='')
_fast_url_quote = _make_fast_url_quote()
-_fast_quote_plus = _make_fast_url_quote(safe=' ', unsafe='+')
+_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+")
def _fast_url_quote_plus(string):
- return _fast_quote_plus(string).replace(' ', '+')
+ return _fast_quote_plus(string).replace(" ", "+")
-def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''):
+def url_quote(string, charset="utf-8", errors="strict", safe="/:", unsafe=""):
"""URL encode a single string with a given encoding.
:param s: the string to quote.
@@ -544,7 +554,7 @@ def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''):
return to_native(bytes(rv))
-def url_quote_plus(string, charset='utf-8', errors='strict', safe=''):
+def url_quote_plus(string, charset="utf-8", errors="strict", safe=""):
"""URL encode a single string with the given encoding and convert
whitespace to "+".
@@ -552,7 +562,7 @@ def url_quote_plus(string, charset='utf-8', errors='strict', safe=''):
:param charset: The charset to be used.
:param safe: An optional sequence of safe characters.
"""
- return url_quote(string, charset, errors, safe + ' ', '+').replace(' ', '+')
+ return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+")
def url_unparse(components):
@@ -562,31 +572,30 @@ def url_unparse(components):
:param components: the parsed URL as tuple which should be converted
into a URL string.
"""
- scheme, netloc, path, query, fragment = \
- normalize_string_tuple(components)
+ scheme, netloc, path, query, fragment = normalize_string_tuple(components)
s = make_literal_wrapper(scheme)
- url = s('')
+ url = s("")
# We generally treat file:///x and file:/x the same which is also
# what browsers seem to do. This also allows us to ignore a schema
# register for netloc utilization or having to differenciate between
# empty and missing netloc.
- if netloc or (scheme and path.startswith(s('/'))):
- if path and path[:1] != s('/'):
- path = s('/') + path
- url = s('//') + (netloc or s('')) + path
+ if netloc or (scheme and path.startswith(s("/"))):
+ if path and path[:1] != s("/"):
+ path = s("/") + path
+ url = s("//") + (netloc or s("")) + path
elif path:
url += path
if scheme:
- url = scheme + s(':') + url
+ url = scheme + s(":") + url
if query:
- url = url + s('?') + query
+ url = url + s("?") + query
if fragment:
- url = url + s('#') + fragment
+ url = url + s("#") + fragment
return url
-def url_unquote(string, charset='utf-8', errors='replace', unsafe=''):
+def url_unquote(string, charset="utf-8", errors="replace", unsafe=""):
"""URL decode a single string with a given encoding. If the charset
is set to `None` no unicode decoding is performed and raw bytes
are returned.
@@ -602,7 +611,7 @@ def url_unquote(string, charset='utf-8', errors='replace', unsafe=''):
return rv
-def url_unquote_plus(s, charset='utf-8', errors='replace'):
+def url_unquote_plus(s, charset="utf-8", errors="replace"):
"""URL decode a single string with the given `charset` and decode "+" to
whitespace.
@@ -616,13 +625,13 @@ def url_unquote_plus(s, charset='utf-8', errors='replace'):
:param errors: The error handling for the `charset` decoding.
"""
if isinstance(s, text_type):
- s = s.replace(u'+', u' ')
+ s = s.replace(u"+", u" ")
else:
- s = s.replace(b'+', b' ')
+ s = s.replace(b"+", b" ")
return url_unquote(s, charset, errors)
-def url_fix(s, charset='utf-8'):
+def url_fix(s, charset="utf-8"):
r"""Sometimes you get an URL by a user that just isn't a real URL because
it contains unsafe characters like ' ' and so on. This function can fix
some of the problems in a similar way browsers handle data entered by the
@@ -638,19 +647,18 @@ def url_fix(s, charset='utf-8'):
# First step is to switch to unicode processing and to convert
# backslashes (which are invalid in URLs anyways) to slashes. This is
# consistent with what Chrome does.
- s = to_unicode(s, charset, 'replace').replace('\\', '/')
+ s = to_unicode(s, charset, "replace").replace("\\", "/")
# For the specific case that we look like a malformed windows URL
# we want to fix this up manually:
- if s.startswith('file://') and s[7:8].isalpha() and s[8:10] in (':/', '|/'):
- s = 'file:///' + s[7:]
+ if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"):
+ s = "file:///" + s[7:]
url = url_parse(s)
- path = url_quote(url.path, charset, safe='/%+$!*\'(),')
- qs = url_quote_plus(url.query, charset, safe=':&%=+$!*\'(),')
- anchor = url_quote_plus(url.fragment, charset, safe=':&%=+$!*\'(),')
- return to_native(url_unparse((url.scheme, url.encode_netloc(),
- path, qs, anchor)))
+ path = url_quote(url.path, charset, safe="/%+$!*'(),")
+ qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),")
+ anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),")
+ return to_native(url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor)))
# not-unreserved characters remain quoted when unquoting to IRI
@@ -661,7 +669,7 @@ def _codec_error_url_quote(e):
"""Used in :func:`uri_to_iri` after unquoting to re-quote any
invalid bytes.
"""
- out = _fast_url_quote(e.object[e.start:e.end])
+ out = _fast_url_quote(e.object[e.start : e.end])
if PY2:
out = out.decode("utf-8")
@@ -752,7 +760,7 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False):
# contains ASCII characters, return it unconverted.
try:
native_iri = to_native(iri)
- ascii_iri = native_iri.encode('ascii')
+ ascii_iri = native_iri.encode("ascii")
# Only return if it doesn't have whitespace. (Why?)
if len(ascii_iri.split()) == 1:
@@ -769,8 +777,15 @@ def iri_to_uri(iri, charset="utf-8", errors="strict", safe_conversion=False):
)
-def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True,
- errors='replace', separator='&', cls=None):
+def url_decode(
+ s,
+ charset="utf-8",
+ decode_keys=False,
+ include_empty=True,
+ errors="replace",
+ separator="&",
+ cls=None,
+):
"""
Parse a querystring and return it as :class:`MultiDict`. There is a
difference in key decoding on different Python versions. On Python 3
@@ -812,16 +827,27 @@ def url_decode(s, charset='utf-8', decode_keys=False, include_empty=True,
if cls is None:
cls = MultiDict
if isinstance(s, text_type) and not isinstance(separator, text_type):
- separator = separator.decode(charset or 'ascii')
+ separator = separator.decode(charset or "ascii")
elif isinstance(s, bytes) and not isinstance(separator, bytes):
- separator = separator.encode(charset or 'ascii')
- return cls(_url_decode_impl(s.split(separator), charset, decode_keys,
- include_empty, errors))
+ separator = separator.encode(charset or "ascii")
+ return cls(
+ _url_decode_impl(
+ s.split(separator), charset, decode_keys, include_empty, errors
+ )
+ )
-def url_decode_stream(stream, charset='utf-8', decode_keys=False,
- include_empty=True, errors='replace', separator='&',
- cls=None, limit=None, return_iterator=False):
+def url_decode_stream(
+ stream,
+ charset="utf-8",
+ decode_keys=False,
+ include_empty=True,
+ errors="replace",
+ separator="&",
+ cls=None,
+ limit=None,
+ return_iterator=False,
+):
"""Works like :func:`url_decode` but decodes a stream. The behavior
of stream and limit follows functions like
:func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is
@@ -849,14 +875,18 @@ def url_decode_stream(stream, charset='utf-8', decode_keys=False,
and an iterator over all decoded pairs is
returned
"""
- from werkzeug.wsgi import make_chunk_iter
+ from .wsgi import make_chunk_iter
+
+ pair_iter = make_chunk_iter(stream, separator, limit)
+ decoder = _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors)
+
if return_iterator:
- cls = lambda x: x
- elif cls is None:
+ return decoder
+
+ if cls is None:
cls = MultiDict
- pair_iter = make_chunk_iter(stream, separator, limit)
- return cls(_url_decode_impl(pair_iter, charset, decode_keys,
- include_empty, errors))
+
+ return cls(decoder)
def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors):
@@ -864,22 +894,23 @@ def _url_decode_impl(pair_iter, charset, decode_keys, include_empty, errors):
if not pair:
continue
s = make_literal_wrapper(pair)
- equal = s('=')
+ equal = s("=")
if equal in pair:
key, value = pair.split(equal, 1)
else:
if not include_empty:
continue
key = pair
- value = s('')
+ value = s("")
key = url_unquote_plus(key, charset, errors)
if charset is not None and PY2 and not decode_keys:
key = try_coerce_native(key)
yield key, url_unquote_plus(value, charset, errors)
-def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None,
- separator=b'&'):
+def url_encode(
+ obj, charset="utf-8", encode_keys=False, sort=False, key=None, separator=b"&"
+):
"""URL encode a dict/`MultiDict`. If a value is `None` it will not appear
in the result string. Per default only values are encoded into the target
charset strings. If `encode_keys` is set to ``True`` unicode keys are
@@ -900,12 +931,19 @@ def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None,
:param key: an optional function to be used for sorting. For more details
check out the :func:`sorted` documentation.
"""
- separator = to_native(separator, 'ascii')
+ separator = to_native(separator, "ascii")
return separator.join(_url_encode_impl(obj, charset, encode_keys, sort, key))
-def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False,
- sort=False, key=None, separator=b'&'):
+def url_encode_stream(
+ obj,
+ stream=None,
+ charset="utf-8",
+ encode_keys=False,
+ sort=False,
+ key=None,
+ separator=b"&",
+):
"""Like :meth:`url_encode` but writes the results to a stream
object. If the stream is `None` a generator over all encoded
pairs is returned.
@@ -924,7 +962,7 @@ def url_encode_stream(obj, stream=None, charset='utf-8', encode_keys=False,
:param key: an optional function to be used for sorting. For more details
check out the :func:`sorted` documentation.
"""
- separator = to_native(separator, 'ascii')
+ separator = to_native(separator, "ascii")
gen = _url_encode_impl(obj, charset, encode_keys, sort, key)
if stream is None:
return gen
@@ -955,55 +993,53 @@ def url_join(base, url, allow_fragments=True):
if not url:
return base
- bscheme, bnetloc, bpath, bquery, bfragment = \
- url_parse(base, allow_fragments=allow_fragments)
- scheme, netloc, path, query, fragment = \
- url_parse(url, bscheme, allow_fragments)
+ bscheme, bnetloc, bpath, bquery, bfragment = url_parse(
+ base, allow_fragments=allow_fragments
+ )
+ scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments)
if scheme != bscheme:
return url
if netloc:
return url_unparse((scheme, netloc, path, query, fragment))
netloc = bnetloc
- if path[:1] == s('/'):
- segments = path.split(s('/'))
+ if path[:1] == s("/"):
+ segments = path.split(s("/"))
elif not path:
- segments = bpath.split(s('/'))
+ segments = bpath.split(s("/"))
if not query:
query = bquery
else:
- segments = bpath.split(s('/'))[:-1] + path.split(s('/'))
+ segments = bpath.split(s("/"))[:-1] + path.split(s("/"))
# If the rightmost part is "./" we want to keep the slash but
# remove the dot.
- if segments[-1] == s('.'):
- segments[-1] = s('')
+ if segments[-1] == s("."):
+ segments[-1] = s("")
# Resolve ".." and "."
- segments = [segment for segment in segments if segment != s('.')]
+ segments = [segment for segment in segments if segment != s(".")]
while 1:
i = 1
n = len(segments) - 1
while i < n:
- if segments[i] == s('..') and \
- segments[i - 1] not in (s(''), s('..')):
- del segments[i - 1:i + 1]
+ if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")):
+ del segments[i - 1 : i + 1]
break
i += 1
else:
break
# Remove trailing ".." if the URL is absolute
- unwanted_marker = [s(''), s('..')]
+ unwanted_marker = [s(""), s("..")]
while segments[:2] == unwanted_marker:
del segments[1]
- path = s('/').join(segments)
+ path = s("/").join(segments)
return url_unparse((scheme, netloc, path, query, fragment))
class Href(object):
-
"""Implements a callable that constructs URLs with the given base. The
function can be called with any number of positional and keyword
arguments which than are used to assemble the URL. Works with URLs
@@ -1054,39 +1090,45 @@ class Href(object):
`sort` and `key` were added.
"""
- def __init__(self, base='./', charset='utf-8', sort=False, key=None):
+ def __init__(self, base="./", charset="utf-8", sort=False, key=None):
if not base:
- base = './'
+ base = "./"
self.base = base
self.charset = charset
self.sort = sort
self.key = key
def __getattr__(self, name):
- if name[:2] == '__':
+ if name[:2] == "__":
raise AttributeError(name)
base = self.base
- if base[-1:] != '/':
- base += '/'
+ if base[-1:] != "/":
+ base += "/"
return Href(url_join(base, name), self.charset, self.sort, self.key)
def __call__(self, *path, **query):
if path and isinstance(path[-1], dict):
if query:
- raise TypeError('keyword arguments and query-dicts '
- 'can\'t be combined')
+ raise TypeError("keyword arguments and query-dicts can't be combined")
query, path = path[-1], path[:-1]
elif query:
- query = dict([(k.endswith('_') and k[:-1] or k, v)
- for k, v in query.items()])
- path = '/'.join([to_unicode(url_quote(x, self.charset), 'ascii')
- for x in path if x is not None]).lstrip('/')
+ query = dict(
+ [(k.endswith("_") and k[:-1] or k, v) for k, v in query.items()]
+ )
+ path = "/".join(
+ [
+ to_unicode(url_quote(x, self.charset), "ascii")
+ for x in path
+ if x is not None
+ ]
+ ).lstrip("/")
rv = self.base
if path:
- if not rv.endswith('/'):
- rv += '/'
- rv = url_join(rv, './' + path)
+ if not rv.endswith("/"):
+ rv += "/"
+ rv = url_join(rv, "./" + path)
if query:
- rv += '?' + to_unicode(url_encode(query, self.charset, sort=self.sort,
- key=self.key), 'ascii')
+ rv += "?" + to_unicode(
+ url_encode(query, self.charset, sort=self.sort, key=self.key), "ascii"
+ )
return to_native(rv)
diff --git a/src/werkzeug/useragents.py b/src/werkzeug/useragents.py
index 90d2649b..e265e093 100644
--- a/src/werkzeug/useragents.py
+++ b/src/werkzeug/useragents.py
@@ -12,80 +12,82 @@
:license: BSD-3-Clause
"""
import re
+import warnings
class UserAgentParser(object):
-
"""A simple user agent parser. Used by the `UserAgent`."""
platforms = (
- ('cros', 'chromeos'),
- ('iphone|ios', 'iphone'),
- ('ipad', 'ipad'),
- (r'darwin|mac|os\s*x', 'macos'),
- ('win', 'windows'),
- (r'android', 'android'),
- ('netbsd', 'netbsd'),
- ('openbsd', 'openbsd'),
- ('freebsd', 'freebsd'),
- ('dragonfly', 'dragonflybsd'),
- ('(sun|i86)os', 'solaris'),
- (r'x11|lin(\b|ux)?', 'linux'),
- (r'nintendo\s+wii', 'wii'),
- ('irix', 'irix'),
- ('hp-?ux', 'hpux'),
- ('aix', 'aix'),
- ('sco|unix_sv', 'sco'),
- ('bsd', 'bsd'),
- ('amiga', 'amiga'),
- ('blackberry|playbook', 'blackberry'),
- ('symbian', 'symbian')
+ ("cros", "chromeos"),
+ ("iphone|ios", "iphone"),
+ ("ipad", "ipad"),
+ (r"darwin|mac|os\s*x", "macos"),
+ ("win", "windows"),
+ (r"android", "android"),
+ ("netbsd", "netbsd"),
+ ("openbsd", "openbsd"),
+ ("freebsd", "freebsd"),
+ ("dragonfly", "dragonflybsd"),
+ ("(sun|i86)os", "solaris"),
+ (r"x11|lin(\b|ux)?", "linux"),
+ (r"nintendo\s+wii", "wii"),
+ ("irix", "irix"),
+ ("hp-?ux", "hpux"),
+ ("aix", "aix"),
+ ("sco|unix_sv", "sco"),
+ ("bsd", "bsd"),
+ ("amiga", "amiga"),
+ ("blackberry|playbook", "blackberry"),
+ ("symbian", "symbian"),
)
browsers = (
- ('googlebot', 'google'),
- ('msnbot', 'msn'),
- ('yahoo', 'yahoo'),
- ('ask jeeves', 'ask'),
- (r'aol|america\s+online\s+browser', 'aol'),
- ('opera', 'opera'),
- ('edge', 'edge'),
- ('chrome|crios', 'chrome'),
- ('seamonkey', 'seamonkey'),
- ('firefox|firebird|phoenix|iceweasel', 'firefox'),
- ('galeon', 'galeon'),
- ('safari|version', 'safari'),
- ('webkit', 'webkit'),
- ('camino', 'camino'),
- ('konqueror', 'konqueror'),
- ('k-meleon', 'kmeleon'),
- ('netscape', 'netscape'),
- (r'msie|microsoft\s+internet\s+explorer|trident/.+? rv:', 'msie'),
- ('lynx', 'lynx'),
- ('links', 'links'),
- ('Baiduspider', 'baidu'),
- ('bingbot', 'bing'),
- ('mozilla', 'mozilla')
+ ("googlebot", "google"),
+ ("msnbot", "msn"),
+ ("yahoo", "yahoo"),
+ ("ask jeeves", "ask"),
+ (r"aol|america\s+online\s+browser", "aol"),
+ ("opera", "opera"),
+ ("edge", "edge"),
+ ("chrome|crios", "chrome"),
+ ("seamonkey", "seamonkey"),
+ ("firefox|firebird|phoenix|iceweasel", "firefox"),
+ ("galeon", "galeon"),
+ ("safari|version", "safari"),
+ ("webkit", "webkit"),
+ ("camino", "camino"),
+ ("konqueror", "konqueror"),
+ ("k-meleon", "kmeleon"),
+ ("netscape", "netscape"),
+ (r"msie|microsoft\s+internet\s+explorer|trident/.+? rv:", "msie"),
+ ("lynx", "lynx"),
+ ("links", "links"),
+ ("Baiduspider", "baidu"),
+ ("bingbot", "bing"),
+ ("mozilla", "mozilla"),
)
- _browser_version_re = r'(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?'
+ _browser_version_re = r"(?:%s)[/\sa-z(]*(\d+[.\da-z]+)?"
_language_re = re.compile(
- r'(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|'
- r'(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)'
+ r"(?:;\s*|\s+)(\b\w{2}\b(?:-\b\w{2}\b)?)\s*;|"
+ r"(?:\(|\[|;)\s*(\b\w{2}\b(?:-\b\w{2}\b)?)\s*(?:\]|\)|;)"
)
def __init__(self):
self.platforms = [(b, re.compile(a, re.I)) for a, b in self.platforms]
- self.browsers = [(b, re.compile(self._browser_version_re % a, re.I))
- for a, b in self.browsers]
+ self.browsers = [
+ (b, re.compile(self._browser_version_re % a, re.I))
+ for a, b in self.browsers
+ ]
def __call__(self, user_agent):
- for platform, regex in self.platforms:
+ for platform, regex in self.platforms: # noqa: B007
match = regex.search(user_agent)
if match is not None:
break
else:
platform = None
- for browser, regex in self.browsers:
+ for browser, regex in self.browsers: # noqa: B007
match = regex.search(user_agent)
if match is not None:
version = match.group(1)
@@ -101,7 +103,6 @@ class UserAgentParser(object):
class UserAgent(object):
-
"""Represents a user agent. Pass it a WSGI environment or a user agent
string and you can inspect some of the details from the user agent
string via the attributes. The following attributes exist:
@@ -181,10 +182,11 @@ class UserAgent(object):
def __init__(self, environ_or_string):
if isinstance(environ_or_string, dict):
- environ_or_string = environ_or_string.get('HTTP_USER_AGENT', '')
+ environ_or_string = environ_or_string.get("HTTP_USER_AGENT", "")
self.string = environ_or_string
- self.platform, self.browser, self.version, self.language = \
- self._parser(environ_or_string)
+ self.platform, self.browser, self.version, self.language = self._parser(
+ environ_or_string
+ )
def to_header(self):
return self.string
@@ -198,16 +200,21 @@ class UserAgent(object):
__bool__ = __nonzero__
def __repr__(self):
- return '<%s %r/%s>' % (
- self.__class__.__name__,
- self.browser,
- self.version
- )
+ return "<%s %r/%s>" % (self.__class__.__name__, self.browser, self.version)
-# conceptionally this belongs in this module but because we want to lazily
-# load the user agent module (which happens in wrappers.py) we have to import
-# it afterwards. The class itself has the module set to this module so
-# pickle, inspect and similar modules treat the object as if it was really
-# implemented here.
-from werkzeug.wrappers import UserAgentMixin # noqa
+# DEPRECATED
+from .wrappers import UserAgentMixin as _UserAgentMixin
+
+
+class UserAgentMixin(_UserAgentMixin):
+ @property
+ def user_agent(self, *args, **kwargs):
+ warnings.warn(
+ "'werkzeug.useragents.UserAgentMixin' should be imported"
+ " from 'werkzeug.wrappers.UserAgentMixin'. This old import"
+ " will be removed in version 1.0.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return super(_UserAgentMixin, self).user_agent
diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py
index 406f62b3..20620572 100644
--- a/src/werkzeug/utils.py
+++ b/src/werkzeug/utils.py
@@ -11,32 +11,47 @@
:license: BSD-3-Clause
"""
import codecs
-import re
import os
-import sys
import pkgutil
+import re
+import sys
import warnings
+from ._compat import iteritems
+from ._compat import PY2
+from ._compat import reraise
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import unichr
+from ._internal import _DictAccessorProperty
+from ._internal import _missing
+from ._internal import _parse_signature
+
try:
from html.entities import name2codepoint
except ImportError:
from htmlentitydefs import name2codepoint
-from werkzeug._compat import unichr, text_type, string_types, iteritems, \
- reraise, PY2
-from werkzeug._internal import _DictAccessorProperty, \
- _parse_signature, _missing
-
-_format_re = re.compile(r'\$(?:(%s)|\{(%s)\})' % (('[a-zA-Z_][a-zA-Z0-9_]*',) * 2))
-_entity_re = re.compile(r'&([^;]+);')
-_filename_ascii_strip_re = re.compile(r'[^A-Za-z0-9_.-]')
-_windows_device_files = ('CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1',
- 'LPT2', 'LPT3', 'PRN', 'NUL')
+_format_re = re.compile(r"\$(?:(%s)|\{(%s)\})" % (("[a-zA-Z_][a-zA-Z0-9_]*",) * 2))
+_entity_re = re.compile(r"&([^;]+);")
+_filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]")
+_windows_device_files = (
+ "CON",
+ "AUX",
+ "COM1",
+ "COM2",
+ "COM3",
+ "COM4",
+ "LPT1",
+ "LPT2",
+ "LPT3",
+ "PRN",
+ "NUL",
+)
class cached_property(property):
-
"""A decorator that converts a function into a lazy property. The
function wrapped is called the first time to retrieve the result
and then that calculated result is used the next time you access
@@ -79,7 +94,6 @@ class cached_property(property):
class environ_property(_DictAccessorProperty):
-
"""Maps request attributes to environment variables. This works not only
for the Werzeug request object, but also any other class with an
environ attribute:
@@ -107,7 +121,6 @@ class environ_property(_DictAccessorProperty):
class header_property(_DictAccessorProperty):
-
"""Like `environ_property` but for headers."""
def lookup(self, obj):
@@ -115,7 +128,6 @@ class header_property(_DictAccessorProperty):
class HTMLBuilder(object):
-
"""Helper object for HTML generation.
Per default there are two instances of that class. The `html` one, and
@@ -141,20 +153,45 @@ class HTMLBuilder(object):
u'<p>&lt;foo&gt;</p>'
"""
- _entity_re = re.compile(r'&([^;]+);')
+ _entity_re = re.compile(r"&([^;]+);")
_entities = name2codepoint.copy()
- _entities['apos'] = 39
- _empty_elements = set([
- 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame',
- 'hr', 'img', 'input', 'keygen', 'isindex', 'link', 'meta', 'param',
- 'source', 'wbr'
- ])
- _boolean_attributes = set([
- 'selected', 'checked', 'compact', 'declare', 'defer', 'disabled',
- 'ismap', 'multiple', 'nohref', 'noresize', 'noshade', 'nowrap'
- ])
- _plaintext_elements = set(['textarea'])
- _c_like_cdata = set(['script', 'style'])
+ _entities["apos"] = 39
+ _empty_elements = {
+ "area",
+ "base",
+ "basefont",
+ "br",
+ "col",
+ "command",
+ "embed",
+ "frame",
+ "hr",
+ "img",
+ "input",
+ "keygen",
+ "isindex",
+ "link",
+ "meta",
+ "param",
+ "source",
+ "wbr",
+ }
+ _boolean_attributes = {
+ "selected",
+ "checked",
+ "compact",
+ "declare",
+ "defer",
+ "disabled",
+ "ismap",
+ "multiple",
+ "nohref",
+ "noresize",
+ "noshade",
+ "nowrap",
+ }
+ _plaintext_elements = {"textarea"}
+ _c_like_cdata = {"script", "style"}
def __init__(self, dialect):
self._dialect = dialect
@@ -163,56 +200,56 @@ class HTMLBuilder(object):
return escape(s)
def __getattr__(self, tag):
- if tag[:2] == '__':
+ if tag[:2] == "__":
raise AttributeError(tag)
def proxy(*children, **arguments):
- buffer = '<' + tag
+ buffer = "<" + tag
for key, value in iteritems(arguments):
if value is None:
continue
- if key[-1] == '_':
+ if key[-1] == "_":
key = key[:-1]
if key in self._boolean_attributes:
if not value:
continue
- if self._dialect == 'xhtml':
+ if self._dialect == "xhtml":
value = '="' + key + '"'
else:
- value = ''
+ value = ""
else:
value = '="' + escape(value) + '"'
- buffer += ' ' + key + value
+ buffer += " " + key + value
if not children and tag in self._empty_elements:
- if self._dialect == 'xhtml':
- buffer += ' />'
+ if self._dialect == "xhtml":
+ buffer += " />"
else:
- buffer += '>'
+ buffer += ">"
return buffer
- buffer += '>'
+ buffer += ">"
- children_as_string = ''.join([text_type(x) for x in children
- if x is not None])
+ children_as_string = "".join(
+ [text_type(x) for x in children if x is not None]
+ )
if children_as_string:
if tag in self._plaintext_elements:
children_as_string = escape(children_as_string)
- elif tag in self._c_like_cdata and self._dialect == 'xhtml':
- children_as_string = '/*<![CDATA[*/' + \
- children_as_string + '/*]]>*/'
- buffer += children_as_string + '</' + tag + '>'
+ elif tag in self._c_like_cdata and self._dialect == "xhtml":
+ children_as_string = (
+ "/*<![CDATA[*/" + children_as_string + "/*]]>*/"
+ )
+ buffer += children_as_string + "</" + tag + ">"
return buffer
+
return proxy
def __repr__(self):
- return '<%s for %r>' % (
- self.__class__.__name__,
- self._dialect
- )
+ return "<%s for %r>" % (self.__class__.__name__, self._dialect)
-html = HTMLBuilder('html')
-xhtml = HTMLBuilder('xhtml')
+html = HTMLBuilder("html")
+xhtml = HTMLBuilder("xhtml")
# https://cgit.freedesktop.org/xdg/shared-mime-info/tree/freedesktop.org.xml.in
# https://www.iana.org/assignments/media-types/media-types.xhtml
@@ -243,11 +280,11 @@ def get_content_type(mimetype, charset):
``application/javascript`` are also given charsets.
"""
if (
- mimetype.startswith('text/')
+ mimetype.startswith("text/")
or mimetype in _charset_mimetypes
- or mimetype.endswith('+xml')
+ or mimetype.endswith("+xml")
):
- mimetype += '; charset=' + charset
+ mimetype += "; charset=" + charset
return mimetype
@@ -294,7 +331,7 @@ def detect_utf_encoding(data):
return "utf-16-le"
if len(head) == 2:
- return "utf-16-be" if head.startswith(b'\x00') else "utf-16-le"
+ return "utf-16-be" if head.startswith(b"\x00") else "utf-16-le"
return "utf-8"
@@ -311,11 +348,13 @@ def format_string(string, context):
:param string: the format string.
:param context: a dict with the variables to insert.
"""
+
def lookup_arg(match):
x = context[match.group(1) or match.group(2)]
if not isinstance(x, string_types):
x = type(string)(x)
return x
+
return _format_re.sub(lookup_arg, string)
@@ -345,21 +384,26 @@ def secure_filename(filename):
"""
if isinstance(filename, text_type):
from unicodedata import normalize
- filename = normalize('NFKD', filename).encode('ascii', 'ignore')
+
+ filename = normalize("NFKD", filename).encode("ascii", "ignore")
if not PY2:
- filename = filename.decode('ascii')
+ filename = filename.decode("ascii")
for sep in os.path.sep, os.path.altsep:
if sep:
- filename = filename.replace(sep, ' ')
- filename = str(_filename_ascii_strip_re.sub('', '_'.join(
- filename.split()))).strip('._')
+ filename = filename.replace(sep, " ")
+ filename = str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip(
+ "._"
+ )
# on nt a couple of special files are present in each folder. We
# have to ensure that the target file is not such a filename. In
# this case we prepend an underline
- if os.name == 'nt' and filename and \
- filename.split('.')[0].upper() in _windows_device_files:
- filename = '_' + filename
+ if (
+ os.name == "nt"
+ and filename
+ and filename.split(".")[0].upper() in _windows_device_files
+ ):
+ filename = "_" + filename
return filename
@@ -376,21 +420,26 @@ def escape(s, quote=None):
:param quote: ignored.
"""
if s is None:
- return ''
- elif hasattr(s, '__html__'):
+ return ""
+ elif hasattr(s, "__html__"):
return text_type(s.__html__())
elif not isinstance(s, string_types):
s = text_type(s)
if quote is not None:
from warnings import warn
+
warn(
"The 'quote' parameter is no longer used as of version 0.9"
" and will be removed in version 1.0.",
DeprecationWarning,
stacklevel=2,
)
- s = s.replace('&', '&amp;').replace('<', '&lt;') \
- .replace('>', '&gt;').replace('"', "&quot;")
+ s = (
+ s.replace("&", "&amp;")
+ .replace("<", "&lt;")
+ .replace(">", "&gt;")
+ .replace('"', "&quot;")
+ )
return s
@@ -400,18 +449,20 @@ def unescape(s):
:param s: the string to unescape.
"""
+
def handle_match(m):
name = m.group(1)
if name in HTMLBuilder._entities:
return unichr(HTMLBuilder._entities[name])
try:
- if name[:2] in ('#x', '#X'):
+ if name[:2] in ("#x", "#X"):
return unichr(int(name[2:], 16))
- elif name.startswith('#'):
+ elif name.startswith("#"):
return unichr(int(name[1:]))
except ValueError:
pass
- return u''
+ return u""
+
return _entity_re.sub(handle_match, s)
@@ -436,22 +487,26 @@ def redirect(location, code=302, Response=None):
unspecified.
"""
if Response is None:
- from werkzeug.wrappers import Response
+ from .wrappers import Response
display_location = escape(location)
if isinstance(location, text_type):
# Safe conversion is necessary here as we might redirect
# to a broken URI scheme (for instance itms-services).
- from werkzeug.urls import iri_to_uri
+ from .urls import iri_to_uri
+
location = iri_to_uri(location, safe_conversion=True)
response = Response(
'<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">\n'
- '<title>Redirecting...</title>\n'
- '<h1>Redirecting...</h1>\n'
- '<p>You should be redirected automatically to target URL: '
- '<a href="%s">%s</a>. If not click the link.' %
- (escape(location), display_location), code, mimetype='text/html')
- response.headers['Location'] = location
+ "<title>Redirecting...</title>\n"
+ "<h1>Redirecting...</h1>\n"
+ "<p>You should be redirected automatically to target URL: "
+ '<a href="%s">%s</a>. If not click the link.'
+ % (escape(location), display_location),
+ code,
+ mimetype="text/html",
+ )
+ response.headers["Location"] = location
return response
@@ -463,10 +518,10 @@ def append_slash_redirect(environ, code=301):
the redirect.
:param code: the status code for the redirect.
"""
- new_path = environ['PATH_INFO'].strip('/') + '/'
- query_string = environ.get('QUERY_STRING')
+ new_path = environ["PATH_INFO"].strip("/") + "/"
+ query_string = environ.get("QUERY_STRING")
if query_string:
- new_path += '?' + query_string
+ new_path += "?" + query_string
return redirect(new_path, code)
@@ -486,17 +541,17 @@ def import_string(import_name, silent=False):
# force the import name to automatically convert to strings
# __import__ is not able to handle unicode strings in the fromlist
# if the module is a package
- import_name = str(import_name).replace(':', '.')
+ import_name = str(import_name).replace(":", ".")
try:
try:
__import__(import_name)
except ImportError:
- if '.' not in import_name:
+ if "." not in import_name:
raise
else:
return sys.modules[import_name]
- module_name, obj_name = import_name.rsplit('.', 1)
+ module_name, obj_name = import_name.rsplit(".", 1)
module = __import__(module_name, globals(), locals(), [obj_name])
try:
return getattr(module, obj_name)
@@ -506,9 +561,8 @@ def import_string(import_name, silent=False):
except ImportError as e:
if not silent:
reraise(
- ImportStringError,
- ImportStringError(import_name, e),
- sys.exc_info()[2])
+ ImportStringError, ImportStringError(import_name, e), sys.exc_info()[2]
+ )
def find_modules(import_path, include_packages=False, recursive=False):
@@ -527,11 +581,11 @@ def find_modules(import_path, include_packages=False, recursive=False):
:return: generator
"""
module = import_string(import_path)
- path = getattr(module, '__path__', None)
+ path = getattr(module, "__path__", None)
if path is None:
- raise ValueError('%r is not a package' % import_path)
- basename = module.__name__ + '.'
- for importer, modname, ispkg in pkgutil.iter_modules(path):
+ raise ValueError("%r is not a package" % import_path)
+ basename = module.__name__ + "."
+ for _importer, modname, ispkg in pkgutil.iter_modules(path):
modname = basename + modname
if ispkg:
if include_packages:
@@ -608,24 +662,32 @@ def bind_arguments(func, args, kwargs):
:param kwargs: a dict of keyword arguments.
:return: a :class:`dict` of bound keyword arguments.
"""
- args, kwargs, missing, extra, extra_positional, \
- arg_spec, vararg_var, kwarg_var = _parse_signature(func)(args, kwargs)
+ (
+ args,
+ kwargs,
+ missing,
+ extra,
+ extra_positional,
+ arg_spec,
+ vararg_var,
+ kwarg_var,
+ ) = _parse_signature(func)(args, kwargs)
values = {}
- for (name, has_default, default), value in zip(arg_spec, args):
+ for (name, _has_default, _default), value in zip(arg_spec, args):
values[name] = value
if vararg_var is not None:
values[vararg_var] = tuple(extra_positional)
elif extra_positional:
- raise TypeError('too many positional arguments')
+ raise TypeError("too many positional arguments")
if kwarg_var is not None:
multikw = set(extra) & set([x[0] for x in arg_spec])
if multikw:
- raise TypeError('got multiple values for keyword argument '
- + repr(next(iter(multikw))))
+ raise TypeError(
+ "got multiple values for keyword argument " + repr(next(iter(multikw)))
+ )
values[kwarg_var] = extra
elif extra:
- raise TypeError('got unexpected keyword argument '
- + repr(next(iter(extra))))
+ raise TypeError("got unexpected keyword argument " + repr(next(iter(extra))))
return values
@@ -637,15 +699,14 @@ class ArgumentValidationError(ValueError):
self.missing = set(missing or ())
self.extra = extra or {}
self.extra_positional = extra_positional or []
- ValueError.__init__(self, 'function arguments invalid. ('
- '%d missing, %d additional)' % (
- len(self.missing),
- len(self.extra) + len(self.extra_positional)
- ))
+ ValueError.__init__(
+ self,
+ "function arguments invalid. (%d missing, %d additional)"
+ % (len(self.missing), len(self.extra) + len(self.extra_positional)),
+ )
class ImportStringError(ImportError):
-
"""Provides information about a failed :func:`import_string` attempt."""
#: String in dotted notation that failed to be imported.
@@ -658,47 +719,51 @@ class ImportStringError(ImportError):
self.exception = exception
msg = (
- 'import_string() failed for %r. Possible reasons are:\n\n'
- '- missing __init__.py in a package;\n'
- '- package or module path not included in sys.path;\n'
- '- duplicated package or module name taking precedence in '
- 'sys.path;\n'
- '- missing module, class, function or variable;\n\n'
- 'Debugged import:\n\n%s\n\n'
- 'Original exception:\n\n%s: %s')
-
- name = ''
+ "import_string() failed for %r. Possible reasons are:\n\n"
+ "- missing __init__.py in a package;\n"
+ "- package or module path not included in sys.path;\n"
+ "- duplicated package or module name taking precedence in "
+ "sys.path;\n"
+ "- missing module, class, function or variable;\n\n"
+ "Debugged import:\n\n%s\n\n"
+ "Original exception:\n\n%s: %s"
+ )
+
+ name = ""
tracked = []
- for part in import_name.replace(':', '.').split('.'):
- name += (name and '.') + part
+ for part in import_name.replace(":", ".").split("."):
+ name += (name and ".") + part
imported = import_string(name, silent=True)
if imported:
- tracked.append((name, getattr(imported, '__file__', None)))
+ tracked.append((name, getattr(imported, "__file__", None)))
else:
- track = ['- %r found in %r.' % (n, i) for n, i in tracked]
- track.append('- %r not found.' % name)
- msg = msg % (import_name, '\n'.join(track),
- exception.__class__.__name__, str(exception))
+ track = ["- %r found in %r." % (n, i) for n, i in tracked]
+ track.append("- %r not found." % name)
+ msg = msg % (
+ import_name,
+ "\n".join(track),
+ exception.__class__.__name__,
+ str(exception),
+ )
break
ImportError.__init__(self, msg)
def __repr__(self):
- return '<%s(%r, %r)>' % (self.__class__.__name__, self.import_name,
- self.exception)
+ return "<%s(%r, %r)>" % (
+ self.__class__.__name__,
+ self.import_name,
+ self.exception,
+ )
# DEPRECATED
-from werkzeug.datastructures import (
- MultiDict as _MultiDict,
- CombinedMultiDict as _CombinedMultiDict,
- Headers as _Headers,
- EnvironHeaders as _EnvironHeaders,
-)
-from werkzeug.http import (
- parse_cookie as _parse_cookie,
- dump_cookie as _dump_cookie,
-)
+from .datastructures import CombinedMultiDict as _CombinedMultiDict
+from .datastructures import EnvironHeaders as _EnvironHeaders
+from .datastructures import Headers as _Headers
+from .datastructures import MultiDict as _MultiDict
+from .http import dump_cookie as _dump_cookie
+from .http import parse_cookie as _parse_cookie
class MultiDict(_MultiDict):
diff --git a/src/werkzeug/wrappers/accept.py b/src/werkzeug/wrappers/accept.py
index 9fe50141..d0620a0a 100644
--- a/src/werkzeug/wrappers/accept.py
+++ b/src/werkzeug/wrappers/accept.py
@@ -17,15 +17,16 @@ class AcceptMixin(object):
"""List of mimetypes this client supports as
:class:`~werkzeug.datastructures.MIMEAccept` object.
"""
- return parse_accept_header(self.environ.get('HTTP_ACCEPT'), MIMEAccept)
+ return parse_accept_header(self.environ.get("HTTP_ACCEPT"), MIMEAccept)
@cached_property
def accept_charsets(self):
"""List of charsets this client supports as
:class:`~werkzeug.datastructures.CharsetAccept` object.
"""
- return parse_accept_header(self.environ.get('HTTP_ACCEPT_CHARSET'),
- CharsetAccept)
+ return parse_accept_header(
+ self.environ.get("HTTP_ACCEPT_CHARSET"), CharsetAccept
+ )
@cached_property
def accept_encodings(self):
@@ -33,7 +34,7 @@ class AcceptMixin(object):
are compression encodings such as gzip. For charsets have a look at
:attr:`accept_charset`.
"""
- return parse_accept_header(self.environ.get('HTTP_ACCEPT_ENCODING'))
+ return parse_accept_header(self.environ.get("HTTP_ACCEPT_ENCODING"))
@cached_property
def accept_languages(self):
@@ -44,5 +45,6 @@ class AcceptMixin(object):
In previous versions this was a regular
:class:`~werkzeug.datastructures.Accept` object.
"""
- return parse_accept_header(self.environ.get('HTTP_ACCEPT_LANGUAGE'),
- LanguageAccept)
+ return parse_accept_header(
+ self.environ.get("HTTP_ACCEPT_LANGUAGE"), LanguageAccept
+ )
diff --git a/src/werkzeug/wrappers/auth.py b/src/werkzeug/wrappers/auth.py
index 2660e9ba..714f7554 100644
--- a/src/werkzeug/wrappers/auth.py
+++ b/src/werkzeug/wrappers/auth.py
@@ -12,7 +12,7 @@ class AuthorizationMixin(object):
@cached_property
def authorization(self):
"""The `Authorization` object in parsed form."""
- header = self.environ.get('HTTP_AUTHORIZATION')
+ header = self.environ.get("HTTP_AUTHORIZATION")
return parse_authorization_header(header)
@@ -22,10 +22,12 @@ class WWWAuthenticateMixin(object):
@property
def www_authenticate(self):
"""The `WWW-Authenticate` header in a parsed form."""
+
def on_update(www_auth):
- if not www_auth and 'www-authenticate' in self.headers:
- del self.headers['www-authenticate']
+ if not www_auth and "www-authenticate" in self.headers:
+ del self.headers["www-authenticate"]
elif www_auth:
- self.headers['WWW-Authenticate'] = www_auth.to_header()
- header = self.headers.get('www-authenticate')
+ self.headers["WWW-Authenticate"] = www_auth.to_header()
+
+ header = self.headers.get("www-authenticate")
return parse_www_authenticate_header(header, on_update)
diff --git a/src/werkzeug/wrappers/base_request.py b/src/werkzeug/wrappers/base_request.py
index c106c9c4..f5c40be9 100644
--- a/src/werkzeug/wrappers/base_request.py
+++ b/src/werkzeug/wrappers/base_request.py
@@ -72,10 +72,10 @@ class BaseRequest(object):
"""
#: the charset for the request, defaults to utf-8
- charset = 'utf-8'
+ charset = "utf-8"
#: the error handling procedure for errors, defaults to 'replace'
- encoding_errors = 'replace'
+ encoding_errors = "replace"
#: the maximum content length. This is forwarded to the form data
#: parsing function (:func:`parse_form_data`). When set and the
@@ -150,7 +150,7 @@ class BaseRequest(object):
def __init__(self, environ, populate_request=True, shallow=False):
self.environ = environ
if populate_request and not shallow:
- self.environ['werkzeug.request'] = self
+ self.environ["werkzeug.request"] = self
self.shallow = shallow
def __repr__(self):
@@ -160,14 +160,11 @@ class BaseRequest(object):
args = []
try:
args.append("'%s'" % to_native(self.url, self.url_charset))
- args.append('[%s]' % self.method)
+ args.append("[%s]" % self.method)
except Exception:
- args.append('(invalid WSGI environ)')
+ args.append("(invalid WSGI environ)")
- return '<%s %s>' % (
- self.__class__.__name__,
- ' '.join(args)
- )
+ return "<%s %s>" % (self.__class__.__name__, " ".join(args))
@property
def url_charset(self):
@@ -197,9 +194,10 @@ class BaseRequest(object):
:return: request object
"""
- from werkzeug.test import EnvironBuilder
- charset = kwargs.pop('charset', cls.charset)
- kwargs['charset'] = charset
+ from ..test import EnvironBuilder
+
+ charset = kwargs.pop("charset", cls.charset)
+ kwargs["charset"] = charset
builder = EnvironBuilder(*args, **kwargs)
try:
return builder.get_request(cls)
@@ -228,7 +226,7 @@ class BaseRequest(object):
#: the request. The return value is then called with the latest
#: two arguments. This makes it possible to use this decorator for
#: both methods and standalone WSGI functions.
- from werkzeug.exceptions import HTTPException
+ from ..exceptions import HTTPException
def application(*args):
request = cls(args[-2])
@@ -241,8 +239,9 @@ class BaseRequest(object):
return update_wrapper(application, f)
- def _get_file_stream(self, total_content_length, content_type, filename=None,
- content_length=None):
+ def _get_file_stream(
+ self, total_content_length, content_type, filename=None, content_length=None
+ ):
"""Called to get a stream for the file upload.
This must provide a file-like class with `read()`, `readline()`
@@ -266,7 +265,8 @@ class BaseRequest(object):
total_content_length=total_content_length,
content_type=content_type,
filename=filename,
- content_length=content_length)
+ content_length=content_length,
+ )
@property
def want_form_data_parsed(self):
@@ -275,7 +275,7 @@ class BaseRequest(object):
.. versionadded:: 0.8
"""
- return bool(self.environ.get('CONTENT_TYPE'))
+ return bool(self.environ.get("CONTENT_TYPE"))
def make_form_data_parser(self):
"""Creates the form data parser. Instantiates the
@@ -283,12 +283,14 @@ class BaseRequest(object):
.. versionadded:: 0.8
"""
- return self.form_data_parser_class(self._get_file_stream,
- self.charset,
- self.encoding_errors,
- self.max_form_memory_size,
- self.max_content_length,
- self.parameter_storage_class)
+ return self.form_data_parser_class(
+ self._get_file_stream,
+ self.charset,
+ self.encoding_errors,
+ self.max_form_memory_size,
+ self.max_content_length,
+ self.parameter_storage_class,
+ )
def _load_form_data(self):
"""Method used internally to retrieve submitted data. After calling
@@ -300,26 +302,30 @@ class BaseRequest(object):
.. versionadded:: 0.8
"""
# abort early if we have already consumed the stream
- if 'form' in self.__dict__:
+ if "form" in self.__dict__:
return
_assert_not_shallow(self)
if self.want_form_data_parsed:
- content_type = self.environ.get('CONTENT_TYPE', '')
+ content_type = self.environ.get("CONTENT_TYPE", "")
content_length = get_content_length(self.environ)
mimetype, options = parse_options_header(content_type)
parser = self.make_form_data_parser()
- data = parser.parse(self._get_stream_for_parsing(),
- mimetype, content_length, options)
+ data = parser.parse(
+ self._get_stream_for_parsing(), mimetype, content_length, options
+ )
else:
- data = (self.stream, self.parameter_storage_class(),
- self.parameter_storage_class())
+ data = (
+ self.stream,
+ self.parameter_storage_class(),
+ self.parameter_storage_class(),
+ )
# inject the values into the instance dict so that we bypass
# our cached_property non-data descriptor.
d = self.__dict__
- d['stream'], d['form'], d['files'] = data
+ d["stream"], d["form"], d["files"] = data
def _get_stream_for_parsing(self):
"""This is the same as accessing :attr:`stream` with the difference
@@ -328,7 +334,7 @@ class BaseRequest(object):
.. versionadded:: 0.9.3
"""
- cached_data = getattr(self, '_cached_data', None)
+ cached_data = getattr(self, "_cached_data", None)
if cached_data is not None:
return BytesIO(cached_data)
return self.stream
@@ -340,8 +346,8 @@ class BaseRequest(object):
.. versionadded:: 0.9
"""
- files = self.__dict__.get('files')
- for key, value in iter_multi_items(files or ()):
+ files = self.__dict__.get("files")
+ for _key, value in iter_multi_items(files or ()):
value.close()
def __enter__(self):
@@ -371,12 +377,14 @@ class BaseRequest(object):
_assert_not_shallow(self)
return get_input_stream(self.environ)
- input_stream = environ_property('wsgi.input', """
- The WSGI input stream.
+ input_stream = environ_property(
+ "wsgi.input",
+ """The WSGI input stream.
- In general it's a bad idea to use this one because you can easily read past
- the boundary. Use the :attr:`stream` instead.
- """)
+ In general it's a bad idea to use this one because you can
+ easily read past the boundary. Use the :attr:`stream`
+ instead.""",
+ )
@cached_property
def args(self):
@@ -389,9 +397,12 @@ class BaseRequest(object):
:attr:`parameter_storage_class` to a different type. This might
be necessary if the order of the form data is important.
"""
- return url_decode(wsgi_get_bytes(self.environ.get('QUERY_STRING', '')),
- self.url_charset, errors=self.encoding_errors,
- cls=self.parameter_storage_class)
+ return url_decode(
+ wsgi_get_bytes(self.environ.get("QUERY_STRING", "")),
+ self.url_charset,
+ errors=self.encoding_errors,
+ cls=self.parameter_storage_class,
+ )
@cached_property
def data(self):
@@ -401,7 +412,7 @@ class BaseRequest(object):
"""
if self.disable_data_descriptor:
- raise AttributeError('data descriptor is disabled')
+ raise AttributeError("data descriptor is disabled")
# XXX: this should eventually be deprecated.
# We trigger form data parsing first which means that the descriptor
@@ -436,7 +447,7 @@ class BaseRequest(object):
.. versionadded:: 0.9
"""
- rv = getattr(self, '_cached_data', None)
+ rv = getattr(self, "_cached_data", None)
if rv is None:
if parse_form_data:
self._load_form_data()
@@ -504,9 +515,12 @@ class BaseRequest(object):
def cookies(self):
"""A :class:`dict` with the contents of all cookies transmitted with
the request."""
- return parse_cookie(self.environ, self.charset,
- self.encoding_errors,
- cls=self.dict_storage_class)
+ return parse_cookie(
+ self.environ,
+ self.charset,
+ self.encoding_errors,
+ cls=self.dict_storage_class,
+ )
@cached_property
def headers(self):
@@ -521,37 +535,39 @@ class BaseRequest(object):
info in the WSGI environment but will always include a leading slash,
even if the URL root is accessed.
"""
- raw_path = wsgi_decoding_dance(self.environ.get('PATH_INFO') or '',
- self.charset, self.encoding_errors)
- return '/' + raw_path.lstrip('/')
+ raw_path = wsgi_decoding_dance(
+ self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors
+ )
+ return "/" + raw_path.lstrip("/")
@cached_property
def full_path(self):
"""Requested path as unicode, including the query string."""
- return self.path + u'?' + to_unicode(self.query_string, self.url_charset)
+ return self.path + u"?" + to_unicode(self.query_string, self.url_charset)
@cached_property
def script_root(self):
"""The root path of the script without the trailing slash."""
- raw_path = wsgi_decoding_dance(self.environ.get('SCRIPT_NAME') or '',
- self.charset, self.encoding_errors)
- return raw_path.rstrip('/')
+ raw_path = wsgi_decoding_dance(
+ self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors
+ )
+ return raw_path.rstrip("/")
@cached_property
def url(self):
"""The reconstructed current URL as IRI.
See also: :attr:`trusted_hosts`.
"""
- return get_current_url(self.environ,
- trusted_hosts=self.trusted_hosts)
+ return get_current_url(self.environ, trusted_hosts=self.trusted_hosts)
@cached_property
def base_url(self):
"""Like :attr:`url` but without the querystring
See also: :attr:`trusted_hosts`.
"""
- return get_current_url(self.environ, strip_querystring=True,
- trusted_hosts=self.trusted_hosts)
+ return get_current_url(
+ self.environ, strip_querystring=True, trusted_hosts=self.trusted_hosts
+ )
@cached_property
def url_root(self):
@@ -559,16 +575,16 @@ class BaseRequest(object):
root as IRI.
See also: :attr:`trusted_hosts`.
"""
- return get_current_url(self.environ, True,
- trusted_hosts=self.trusted_hosts)
+ return get_current_url(self.environ, True, trusted_hosts=self.trusted_hosts)
@cached_property
def host_url(self):
"""Just the host with scheme as IRI.
See also: :attr:`trusted_hosts`.
"""
- return get_current_url(self.environ, host_only=True,
- trusted_hosts=self.trusted_hosts)
+ return get_current_url(
+ self.environ, host_only=True, trusted_hosts=self.trusted_hosts
+ )
@cached_property
def host(self):
@@ -578,39 +594,51 @@ class BaseRequest(object):
return get_host(self.environ, trusted_hosts=self.trusted_hosts)
query_string = environ_property(
- 'QUERY_STRING', '', read_only=True,
- load_func=wsgi_get_bytes, doc='The URL parameters as raw bytestring.')
+ "QUERY_STRING",
+ "",
+ read_only=True,
+ load_func=wsgi_get_bytes,
+ doc="The URL parameters as raw bytestring.",
+ )
method = environ_property(
- 'REQUEST_METHOD', 'GET', read_only=True,
+ "REQUEST_METHOD",
+ "GET",
+ read_only=True,
load_func=lambda x: x.upper(),
- doc="The request method. (For example ``'GET'`` or ``'POST'``).")
+ doc="The request method. (For example ``'GET'`` or ``'POST'``).",
+ )
@cached_property
def access_route(self):
"""If a forwarded header exists this is a list of all ip addresses
from the client ip to the last proxy server.
"""
- if 'HTTP_X_FORWARDED_FOR' in self.environ:
- addr = self.environ['HTTP_X_FORWARDED_FOR'].split(',')
+ if "HTTP_X_FORWARDED_FOR" in self.environ:
+ addr = self.environ["HTTP_X_FORWARDED_FOR"].split(",")
return self.list_storage_class([x.strip() for x in addr])
- elif 'REMOTE_ADDR' in self.environ:
- return self.list_storage_class([self.environ['REMOTE_ADDR']])
+ elif "REMOTE_ADDR" in self.environ:
+ return self.list_storage_class([self.environ["REMOTE_ADDR"]])
return self.list_storage_class()
@property
def remote_addr(self):
"""The remote address of the client."""
- return self.environ.get('REMOTE_ADDR')
-
- remote_user = environ_property('REMOTE_USER', doc='''
- If the server supports user authentication, and the script is
- protected, this attribute contains the username the user has
- authenticated as.''')
-
- scheme = environ_property('wsgi.url_scheme', doc='''
+ return self.environ.get("REMOTE_ADDR")
+
+ remote_user = environ_property(
+ "REMOTE_USER",
+ doc="""If the server supports user authentication, and the
+ script is protected, this attribute contains the username the
+ user has authenticated as.""",
+ )
+
+ scheme = environ_property(
+ "wsgi.url_scheme",
+ doc="""
URL scheme (http or https).
- .. versionadded:: 0.7''')
+ .. versionadded:: 0.7""",
+ )
@property
def is_xhr(self):
@@ -630,28 +658,36 @@ class BaseRequest(object):
" is not standard and is unreliable. You may be able to use"
" 'accept_mimetypes' instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
- return self.environ.get(
- 'HTTP_X_REQUESTED_WITH', ''
- ).lower() == 'xmlhttprequest'
-
- is_secure = property(lambda self: self.environ['wsgi.url_scheme'] == 'https',
- doc='`True` if the request is secure.')
- is_multithread = environ_property('wsgi.multithread', doc='''
- boolean that is `True` if the application is served by
- a multithreaded WSGI server.''')
- is_multiprocess = environ_property('wsgi.multiprocess', doc='''
- boolean that is `True` if the application is served by
- a WSGI server that spawns multiple processes.''')
- is_run_once = environ_property('wsgi.run_once', doc='''
- boolean that is `True` if the application will be executed only
- once in a process lifetime. This is the case for CGI for example,
- but it's not guaranteed that the execution only happens one time.''')
+ return self.environ.get("HTTP_X_REQUESTED_WITH", "").lower() == "xmlhttprequest"
+
+ is_secure = property(
+ lambda self: self.environ["wsgi.url_scheme"] == "https",
+ doc="`True` if the request is secure.",
+ )
+ is_multithread = environ_property(
+ "wsgi.multithread",
+ doc="""boolean that is `True` if the application is served by a
+ multithreaded WSGI server.""",
+ )
+ is_multiprocess = environ_property(
+ "wsgi.multiprocess",
+ doc="""boolean that is `True` if the application is served by a
+ WSGI server that spawns multiple processes.""",
+ )
+ is_run_once = environ_property(
+ "wsgi.run_once",
+ doc="""boolean that is `True` if the application will be
+ executed only once in a process lifetime. This is the case for
+ CGI for example, but it's not guaranteed that the execution only
+ happens one time.""",
+ )
def _assert_not_shallow(request):
if request.shallow:
- raise RuntimeError('A shallow request tried to consume '
- 'form data. If you really want to do '
- 'that, set `shallow` to False.')
+ raise RuntimeError(
+ "A shallow request tried to consume form data. If you really"
+ " want to do that, set `shallow` to False."
+ )
diff --git a/src/werkzeug/wrappers/base_response.py b/src/werkzeug/wrappers/base_response.py
index d7a8763b..d944a7d2 100644
--- a/src/werkzeug/wrappers/base_response.py
+++ b/src/werkzeug/wrappers/base_response.py
@@ -22,6 +22,7 @@ def _run_wsgi_app(*args):
"""
global _run_wsgi_app
from ..test import run_wsgi_app as _run_wsgi_app
+
return _run_wsgi_app(*args)
@@ -36,7 +37,7 @@ def _warn_if_string(iterable):
" client one character at a time. This is almost never"
" intended behavior, use 'response.data' to assign strings"
" to the response object.",
- stacklevel=2
+ stacklevel=2,
)
@@ -129,13 +130,13 @@ class BaseResponse(object):
"""
#: the charset of the response.
- charset = 'utf-8'
+ charset = "utf-8"
#: the default status if none is provided.
default_status = 200
#: the default mimetype if none is provided.
- default_mimetype = 'text/plain'
+ default_mimetype = "text/plain"
#: if set to `False` accessing properties on the response object will
#: not try to consume the response iterator and convert it into a list.
@@ -169,8 +170,15 @@ class BaseResponse(object):
#: .. _`cookie`: http://browsercookielimits.squawky.net/
max_cookie_size = 4093
- def __init__(self, response=None, status=None, headers=None,
- mimetype=None, content_type=None, direct_passthrough=False):
+ def __init__(
+ self,
+ response=None,
+ status=None,
+ headers=None,
+ mimetype=None,
+ content_type=None,
+ direct_passthrough=False,
+ ):
if isinstance(headers, Headers):
self.headers = headers
elif not headers:
@@ -179,13 +187,13 @@ class BaseResponse(object):
self.headers = Headers(headers)
if content_type is None:
- if mimetype is None and 'content-type' not in self.headers:
+ if mimetype is None and "content-type" not in self.headers:
mimetype = self.default_mimetype
if mimetype is not None:
mimetype = get_content_type(mimetype, self.charset)
content_type = mimetype
if content_type is not None:
- self.headers['Content-Type'] = content_type
+ self.headers["Content-Type"] = content_type
if status is None:
status = self.default_status
if isinstance(status, integer_types):
@@ -218,14 +226,10 @@ class BaseResponse(object):
def __repr__(self):
if self.is_sequence:
- body_info = '%d bytes' % sum(map(len, self.iter_encoded()))
+ body_info = "%d bytes" % sum(map(len, self.iter_encoded()))
else:
- body_info = 'streamed' if self.is_streamed else 'likely-streamed'
- return '<%s %s [%s]>' % (
- self.__class__.__name__,
- body_info,
- self.status
- )
+ body_info = "streamed" if self.is_streamed else "likely-streamed"
+ return "<%s %s [%s]>" % (self.__class__.__name__, body_info, self.status)
@classmethod
def force_type(cls, response, environ=None):
@@ -258,8 +262,10 @@ class BaseResponse(object):
"""
if not isinstance(response, BaseResponse):
if environ is None:
- raise TypeError('cannot convert WSGI application into '
- 'response objects without an environ')
+ raise TypeError(
+ "cannot convert WSGI application into response"
+ " objects without an environ"
+ )
response = BaseResponse(*_run_wsgi_app(response, environ))
response.__class__ = cls
return response
@@ -286,11 +292,13 @@ class BaseResponse(object):
def _set_status_code(self, code):
self._status_code = code
try:
- self._status = '%d %s' % (code, HTTP_STATUS_CODES[code].upper())
+ self._status = "%d %s" % (code, HTTP_STATUS_CODES[code].upper())
except KeyError:
- self._status = '%d UNKNOWN' % code
- status_code = property(_get_status_code, _set_status_code,
- doc='The HTTP Status code as number')
+ self._status = "%d UNKNOWN" % code
+
+ status_code = property(
+ _get_status_code, _set_status_code, doc="The HTTP Status code as number"
+ )
del _get_status_code, _set_status_code
def _get_status(self):
@@ -300,16 +308,17 @@ class BaseResponse(object):
try:
self._status = to_native(value)
except AttributeError:
- raise TypeError('Invalid status argument')
+ raise TypeError("Invalid status argument")
try:
self._status_code = int(self._status.split(None, 1)[0])
except ValueError:
self._status_code = 0
- self._status = '0 %s' % self._status
+ self._status = "0 %s" % self._status
except IndexError:
- raise ValueError('Empty status argument')
- status = property(_get_status, _set_status, doc='The HTTP Status code')
+ raise ValueError("Empty status argument")
+
+ status = property(_get_status, _set_status, doc="The HTTP Status code")
del _get_status, _set_status
def get_data(self, as_text=False):
@@ -326,7 +335,7 @@ class BaseResponse(object):
.. versionadded:: 0.9
"""
self._ensure_sequence()
- rv = b''.join(self.iter_encoded())
+ rv = b"".join(self.iter_encoded())
if as_text:
rv = rv.decode(self.charset)
return rv
@@ -346,12 +355,12 @@ class BaseResponse(object):
value = bytes(value)
self.response = [value]
if self.automatically_set_content_length:
- self.headers['Content-Length'] = str(len(value))
+ self.headers["Content-Length"] = str(len(value))
data = property(
get_data,
set_data,
- doc='A descriptor that calls :meth:`get_data` and :meth:`set_data`.'
+ doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.",
)
def calculate_content_length(self):
@@ -375,14 +384,16 @@ class BaseResponse(object):
self.response = list(self.response)
return
if self.direct_passthrough:
- raise RuntimeError('Attempted implicit sequence conversion '
- 'but the response object is in direct '
- 'passthrough mode.')
+ raise RuntimeError(
+ "Attempted implicit sequence conversion but the"
+ " response object is in direct passthrough mode."
+ )
if not self.implicit_sequence_conversion:
- raise RuntimeError('The response object required the iterable '
- 'to be a sequence, but the implicit '
- 'conversion was disabled. Call '
- 'make_sequence() yourself.')
+ raise RuntimeError(
+ "The response object required the iterable to be a"
+ " sequence, but the implicit conversion was disabled."
+ " Call make_sequence() yourself."
+ )
self.make_sequence()
def make_sequence(self):
@@ -397,7 +408,7 @@ class BaseResponse(object):
# if we consume an iterable we have to ensure that the close
# method of the iterable is called if available when we tear
# down the response
- close = getattr(self.response, 'close', None)
+ close = getattr(self.response, "close", None)
self.response = list(self.iter_encoded())
if close is not None:
self.call_on_close(close)
@@ -415,9 +426,18 @@ class BaseResponse(object):
# value from get_app_iter or iter_encoded.
return _iter_encoded(self.response, self.charset)
- def set_cookie(self, key, value='', max_age=None, expires=None,
- path='/', domain=None, secure=False, httponly=False,
- samesite=None):
+ def set_cookie(
+ self,
+ key,
+ value="",
+ max_age=None,
+ expires=None,
+ path="/",
+ domain=None,
+ secure=False,
+ httponly=False,
+ samesite=None,
+ ):
"""Sets a cookie. The parameters are the same as in the cookie `Morsel`
object in the Python standard library but it accepts unicode data, too.
@@ -445,21 +465,24 @@ class BaseResponse(object):
be attached to requests if those requests are
"same-site".
"""
- self.headers.add('Set-Cookie', dump_cookie(
- key,
- value=value,
- max_age=max_age,
- expires=expires,
- path=path,
- domain=domain,
- secure=secure,
- httponly=httponly,
- charset=self.charset,
- max_size=self.max_cookie_size,
- samesite=samesite
- ))
-
- def delete_cookie(self, key, path='/', domain=None):
+ self.headers.add(
+ "Set-Cookie",
+ dump_cookie(
+ key,
+ value=value,
+ max_age=max_age,
+ expires=expires,
+ path=path,
+ domain=domain,
+ secure=secure,
+ httponly=httponly,
+ charset=self.charset,
+ max_size=self.max_cookie_size,
+ samesite=samesite,
+ ),
+ )
+
+ def delete_cookie(self, key, path="/", domain=None):
"""Delete a cookie. Fails silently if key doesn't exist.
:param key: the key (name) of the cookie to be deleted.
@@ -503,7 +526,7 @@ class BaseResponse(object):
.. versionadded:: 0.9
Can now be used in a with statement.
"""
- if hasattr(self.response, 'close'):
+ if hasattr(self.response, "close"):
self.response.close()
for func in self._on_close:
func()
@@ -525,7 +548,7 @@ class BaseResponse(object):
# we explicitly set the length to a list of the *encoded* response
# iterator. Even if the implicit sequence conversion is disabled.
self.response = list(self.iter_encoded())
- self.headers['Content-Length'] = str(sum(map(len, self.response)))
+ self.headers["Content-Length"] = str(sum(map(len, self.response)))
def get_wsgi_headers(self, environ):
"""This is automatically called right before the response is started
@@ -562,11 +585,11 @@ class BaseResponse(object):
# speedup.
for key, value in headers:
ikey = key.lower()
- if ikey == u'location':
+ if ikey == u"location":
location = value
- elif ikey == u'content-location':
+ elif ikey == u"content-location":
content_location = value
- elif ikey == u'content-length':
+ elif ikey == u"content-length":
content_length = value
# make sure the location header is an absolute URL
@@ -583,17 +606,17 @@ class BaseResponse(object):
current_url = iri_to_uri(current_url)
location = url_join(current_url, location)
if location != old_location:
- headers['Location'] = location
+ headers["Location"] = location
# make sure the content location is a URL
- if content_location is not None and \
- isinstance(content_location, text_type):
- headers['Content-Location'] = iri_to_uri(content_location)
+ if content_location is not None and isinstance(content_location, text_type):
+ headers["Content-Location"] = iri_to_uri(content_location)
if 100 <= status < 200 or status == 204:
- # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a Content-Length header field
- # in any response with a status code of 1xx (Informational) or 204 (No Content)."
- headers.remove('Content-Length')
+ # Per section 3.3.2 of RFC 7230, "a server MUST NOT send a
+ # Content-Length header field in any response with a status
+ # code of 1xx (Informational) or 204 (No Content)."
+ headers.remove("Content-Length")
elif status == 304:
remove_entity_headers(headers)
@@ -602,19 +625,21 @@ class BaseResponse(object):
# flattening the iterator or encoding of unicode strings in
# the response. We however should not do that if we have a 304
# response.
- if self.automatically_set_content_length and \
- self.is_sequence and content_length is None and \
- status not in (204, 304) and \
- not (100 <= status < 200):
+ if (
+ self.automatically_set_content_length
+ and self.is_sequence
+ and content_length is None
+ and status not in (204, 304)
+ and not (100 <= status < 200)
+ ):
try:
- content_length = sum(len(to_bytes(x, 'ascii'))
- for x in self.response)
+ content_length = sum(len(to_bytes(x, "ascii")) for x in self.response)
except UnicodeError:
# aha, something non-bytestringy in there, too bad, we
# can't safely figure out the length of the response.
pass
else:
- headers['Content-Length'] = str(content_length)
+ headers["Content-Length"] = str(content_length)
return headers
@@ -633,8 +658,11 @@ class BaseResponse(object):
:return: a response iterable.
"""
status = self.status_code
- if environ['REQUEST_METHOD'] == 'HEAD' or \
- 100 <= status < 200 or status in (204, 304):
+ if (
+ environ["REQUEST_METHOD"] == "HEAD"
+ or 100 <= status < 200
+ or status in (204, 304)
+ ):
iterable = ()
elif self.direct_passthrough:
if __debug__:
diff --git a/src/werkzeug/wrappers/common_descriptors.py b/src/werkzeug/wrappers/common_descriptors.py
index 3ad0474a..e4107ee0 100644
--- a/src/werkzeug/wrappers/common_descriptors.py
+++ b/src/werkzeug/wrappers/common_descriptors.py
@@ -26,11 +26,13 @@ class CommonRequestDescriptorsMixin(object):
.. versionadded:: 0.5
"""
- content_type = environ_property('CONTENT_TYPE', doc='''
- The Content-Type entity-header field indicates the media type of
- the entity-body sent to the recipient or, in the case of the HEAD
- method, the media type that would have been sent had the request
- been a GET.''')
+ content_type = environ_property(
+ "CONTENT_TYPE",
+ doc="""The Content-Type entity-header field indicates the media
+ type of the entity-body sent to the recipient or, in the case of
+ the HEAD method, the media type that would have been sent had
+ the request been a GET.""",
+ )
@cached_property
def content_length(self):
@@ -41,40 +43,58 @@ class CommonRequestDescriptorsMixin(object):
"""
return get_content_length(self.environ)
- content_encoding = environ_property('HTTP_CONTENT_ENCODING', doc='''
- The Content-Encoding entity-header field is used as a modifier to the
- media-type. When present, its value indicates what additional content
- codings have been applied to the entity-body, and thus what decoding
- mechanisms must be applied in order to obtain the media-type
- referenced by the Content-Type header field.
-
- .. versionadded:: 0.9''')
- content_md5 = environ_property('HTTP_CONTENT_MD5', doc='''
- The Content-MD5 entity-header field, as defined in RFC 1864, is an
- MD5 digest of the entity-body for the purpose of providing an
- end-to-end message integrity check (MIC) of the entity-body. (Note:
- a MIC is good for detecting accidental modification of the
- entity-body in transit, but is not proof against malicious attacks.)
-
- .. versionadded:: 0.9''')
- referrer = environ_property('HTTP_REFERER', doc='''
- The Referer[sic] request-header field allows the client to specify,
- for the server's benefit, the address (URI) of the resource from which
- the Request-URI was obtained (the "referrer", although the header
- field is misspelled).''')
- date = environ_property('HTTP_DATE', None, parse_date, doc='''
- The Date general-header field represents the date and time at which
- the message was originated, having the same semantics as orig-date
- in RFC 822.''')
- max_forwards = environ_property('HTTP_MAX_FORWARDS', None, int, doc='''
- The Max-Forwards request-header field provides a mechanism with the
- TRACE and OPTIONS methods to limit the number of proxies or gateways
- that can forward the request to the next inbound server.''')
+ content_encoding = environ_property(
+ "HTTP_CONTENT_ENCODING",
+ doc="""The Content-Encoding entity-header field is used as a
+ modifier to the media-type. When present, its value indicates
+ what additional content codings have been applied to the
+ entity-body, and thus what decoding mechanisms must be applied
+ in order to obtain the media-type referenced by the Content-Type
+ header field.
+
+ .. versionadded:: 0.9""",
+ )
+ content_md5 = environ_property(
+ "HTTP_CONTENT_MD5",
+ doc="""The Content-MD5 entity-header field, as defined in
+ RFC 1864, is an MD5 digest of the entity-body for the purpose of
+ providing an end-to-end message integrity check (MIC) of the
+ entity-body. (Note: a MIC is good for detecting accidental
+ modification of the entity-body in transit, but is not proof
+ against malicious attacks.)
+
+ .. versionadded:: 0.9""",
+ )
+ referrer = environ_property(
+ "HTTP_REFERER",
+ doc="""The Referer[sic] request-header field allows the client
+ to specify, for the server's benefit, the address (URI) of the
+ resource from which the Request-URI was obtained (the
+ "referrer", although the header field is misspelled).""",
+ )
+ date = environ_property(
+ "HTTP_DATE",
+ None,
+ parse_date,
+ doc="""The Date general-header field represents the date and
+ time at which the message was originated, having the same
+ semantics as orig-date in RFC 822.""",
+ )
+ max_forwards = environ_property(
+ "HTTP_MAX_FORWARDS",
+ None,
+ int,
+ doc="""The Max-Forwards request-header field provides a
+ mechanism with the TRACE and OPTIONS methods to limit the number
+ of proxies or gateways that can forward the request to the next
+ inbound server.""",
+ )
def _parse_content_type(self):
- if not hasattr(self, '_parsed_content_type'):
- self._parsed_content_type = \
- parse_options_header(self.environ.get('CONTENT_TYPE', ''))
+ if not hasattr(self, "_parsed_content_type"):
+ self._parsed_content_type = parse_options_header(
+ self.environ.get("CONTENT_TYPE", "")
+ )
@property
def mimetype(self):
@@ -103,7 +123,7 @@ class CommonRequestDescriptorsMixin(object):
optional behavior from the viewpoint of the protocol; however, some
systems MAY require that behavior be consistent with the directives.
"""
- return parse_set_header(self.environ.get('HTTP_PRAGMA', ''))
+ return parse_set_header(self.environ.get("HTTP_PRAGMA", ""))
class CommonResponseDescriptorsMixin(object):
@@ -112,115 +132,157 @@ class CommonResponseDescriptorsMixin(object):
HTTP headers with automatic type conversion.
"""
- def _get_mimetype(self):
- ct = self.headers.get('content-type')
+ @property
+ def mimetype(self):
+ """The mimetype (content type without charset etc.)"""
+ ct = self.headers.get("content-type")
if ct:
- return ct.split(';')[0].strip()
+ return ct.split(";")[0].strip()
+
+ @mimetype.setter
+ def mimetype(self, value):
+ self.headers["Content-Type"] = get_content_type(value, self.charset)
+
+ @property
+ def mimetype_params(self):
+ """The mimetype parameters as dict. For example if the
+ content type is ``text/html; charset=utf-8`` the params would be
+ ``{'charset': 'utf-8'}``.
- def _set_mimetype(self, value):
- self.headers['Content-Type'] = get_content_type(value, self.charset)
+ .. versionadded:: 0.5
+ """
- def _get_mimetype_params(self):
def on_update(d):
- self.headers['Content-Type'] = \
- dump_options_header(self.mimetype, d)
- d = parse_options_header(self.headers.get('content-type', ''))[1]
+ self.headers["Content-Type"] = dump_options_header(self.mimetype, d)
+
+ d = parse_options_header(self.headers.get("content-type", ""))[1]
return CallbackDict(d, on_update)
- mimetype = property(_get_mimetype, _set_mimetype, doc='''
- The mimetype (content type without charset etc.)''')
- mimetype_params = property(_get_mimetype_params, doc='''
- The mimetype parameters as dict. For example if the content
- type is ``text/html; charset=utf-8`` the params would be
- ``{'charset': 'utf-8'}``.
+ location = header_property(
+ "Location",
+ doc="""The Location response-header field is used to redirect
+ the recipient to a location other than the Request-URI for
+ completion of the request or identification of a new
+ resource.""",
+ )
+ age = header_property(
+ "Age",
+ None,
+ parse_age,
+ dump_age,
+ doc="""The Age response-header field conveys the sender's
+ estimate of the amount of time since the response (or its
+ revalidation) was generated at the origin server.
- .. versionadded:: 0.5
- ''')
- location = header_property('Location', doc='''
- The Location response-header field is used to redirect the recipient
- to a location other than the Request-URI for completion of the request
- or identification of a new resource.''')
- age = header_property('Age', None, parse_age, dump_age, doc='''
- The Age response-header field conveys the sender's estimate of the
- amount of time since the response (or its revalidation) was
- generated at the origin server.
-
- Age values are non-negative decimal integers, representing time in
- seconds.''')
- content_type = header_property('Content-Type', doc='''
- The Content-Type entity-header field indicates the media type of the
- entity-body sent to the recipient or, in the case of the HEAD method,
- the media type that would have been sent had the request been a GET.
- ''')
- content_length = header_property('Content-Length', None, int, str, doc='''
- The Content-Length entity-header field indicates the size of the
- entity-body, in decimal number of OCTETs, sent to the recipient or,
- in the case of the HEAD method, the size of the entity-body that would
- have been sent had the request been a GET.''')
- content_location = header_property('Content-Location', doc='''
- The Content-Location entity-header field MAY be used to supply the
- resource location for the entity enclosed in the message when that
- entity is accessible from a location separate from the requested
- resource's URI.''')
- content_encoding = header_property('Content-Encoding', doc='''
- The Content-Encoding entity-header field is used as a modifier to the
- media-type. When present, its value indicates what additional content
- codings have been applied to the entity-body, and thus what decoding
- mechanisms must be applied in order to obtain the media-type
- referenced by the Content-Type header field.''')
- content_md5 = header_property('Content-MD5', doc='''
- The Content-MD5 entity-header field, as defined in RFC 1864, is an
- MD5 digest of the entity-body for the purpose of providing an
- end-to-end message integrity check (MIC) of the entity-body. (Note:
- a MIC is good for detecting accidental modification of the
- entity-body in transit, but is not proof against malicious attacks.)
- ''')
- date = header_property('Date', None, parse_date, http_date, doc='''
- The Date general-header field represents the date and time at which
- the message was originated, having the same semantics as orig-date
- in RFC 822.''')
- expires = header_property('Expires', None, parse_date, http_date, doc='''
- The Expires entity-header field gives the date/time after which the
- response is considered stale. A stale cache entry may not normally be
- returned by a cache.''')
- last_modified = header_property('Last-Modified', None, parse_date,
- http_date, doc='''
- The Last-Modified entity-header field indicates the date and time at
- which the origin server believes the variant was last modified.''')
-
- def _get_retry_after(self):
- value = self.headers.get('retry-after')
+ Age values are non-negative decimal integers, representing time
+ in seconds.""",
+ )
+ content_type = header_property(
+ "Content-Type",
+ doc="""The Content-Type entity-header field indicates the media
+ type of the entity-body sent to the recipient or, in the case of
+ the HEAD method, the media type that would have been sent had
+ the request been a GET.""",
+ )
+ content_length = header_property(
+ "Content-Length",
+ None,
+ int,
+ str,
+ doc="""The Content-Length entity-header field indicates the size
+ of the entity-body, in decimal number of OCTETs, sent to the
+ recipient or, in the case of the HEAD method, the size of the
+ entity-body that would have been sent had the request been a
+ GET.""",
+ )
+ content_location = header_property(
+ "Content-Location",
+ doc="""The Content-Location entity-header field MAY be used to
+ supply the resource location for the entity enclosed in the
+ message when that entity is accessible from a location separate
+ from the requested resource's URI.""",
+ )
+ content_encoding = header_property(
+ "Content-Encoding",
+ doc="""The Content-Encoding entity-header field is used as a
+ modifier to the media-type. When present, its value indicates
+ what additional content codings have been applied to the
+ entity-body, and thus what decoding mechanisms must be applied
+ in order to obtain the media-type referenced by the Content-Type
+ header field.""",
+ )
+ content_md5 = header_property(
+ "Content-MD5",
+ doc="""The Content-MD5 entity-header field, as defined in
+ RFC 1864, is an MD5 digest of the entity-body for the purpose of
+ providing an end-to-end message integrity check (MIC) of the
+ entity-body. (Note: a MIC is good for detecting accidental
+ modification of the entity-body in transit, but is not proof
+ against malicious attacks.)""",
+ )
+ date = header_property(
+ "Date",
+ None,
+ parse_date,
+ http_date,
+ doc="""The Date general-header field represents the date and
+ time at which the message was originated, having the same
+ semantics as orig-date in RFC 822.""",
+ )
+ expires = header_property(
+ "Expires",
+ None,
+ parse_date,
+ http_date,
+ doc="""The Expires entity-header field gives the date/time after
+ which the response is considered stale. A stale cache entry may
+ not normally be returned by a cache.""",
+ )
+ last_modified = header_property(
+ "Last-Modified",
+ None,
+ parse_date,
+ http_date,
+ doc="""The Last-Modified entity-header field indicates the date
+ and time at which the origin server believes the variant was
+ last modified.""",
+ )
+
+ @property
+ def retry_after(self):
+ """The Retry-After response-header field can be used with a
+ 503 (Service Unavailable) response to indicate how long the
+ service is expected to be unavailable to the requesting client.
+
+ Time in seconds until expiration or date.
+ """
+ value = self.headers.get("retry-after")
if value is None:
return
elif value.isdigit():
return datetime.utcnow() + timedelta(seconds=int(value))
return parse_date(value)
- def _set_retry_after(self, value):
+ @retry_after.setter
+ def retry_after(self, value):
if value is None:
- if 'retry-after' in self.headers:
- del self.headers['retry-after']
+ if "retry-after" in self.headers:
+ del self.headers["retry-after"]
return
elif isinstance(value, datetime):
value = http_date(value)
else:
value = str(value)
- self.headers['Retry-After'] = value
-
- retry_after = property(_get_retry_after, _set_retry_after, doc='''
- The Retry-After response-header field can be used with a 503 (Service
- Unavailable) response to indicate how long the service is expected
- to be unavailable to the requesting client.
-
- Time in seconds until expiration or date.''')
+ self.headers["Retry-After"] = value
- def _set_property(name, doc=None):
+ def _set_property(name, doc=None): # noqa: B902
def fget(self):
def on_update(header_set):
if not header_set and name in self.headers:
del self.headers[name]
elif header_set:
self.headers[name] = header_set.to_header()
+
return parse_set_header(self.headers.get(name), on_update)
def fset(self, value):
@@ -230,24 +292,31 @@ class CommonResponseDescriptorsMixin(object):
self.headers[name] = value
else:
self.headers[name] = dump_header(value)
+
return property(fget, fset, doc=doc)
- vary = _set_property('Vary', doc='''
- The Vary field value indicates the set of request-header fields that
- fully determines, while the response is fresh, whether a cache is
- permitted to use the response to reply to a subsequent request
- without revalidation.''')
- content_language = _set_property('Content-Language', doc='''
- The Content-Language entity-header field describes the natural
- language(s) of the intended audience for the enclosed entity. Note
- that this might not be equivalent to all the languages used within
- the entity-body.''')
- allow = _set_property('Allow', doc='''
- The Allow entity-header field lists the set of methods supported
- by the resource identified by the Request-URI. The purpose of this
- field is strictly to inform the recipient of valid methods
- associated with the resource. An Allow header field MUST be
- present in a 405 (Method Not Allowed) response.''')
-
- del _set_property, _get_mimetype, _set_mimetype, _get_retry_after, \
- _set_retry_after
+ vary = _set_property(
+ "Vary",
+ doc="""The Vary field value indicates the set of request-header
+ fields that fully determines, while the response is fresh,
+ whether a cache is permitted to use the response to reply to a
+ subsequent request without revalidation.""",
+ )
+ content_language = _set_property(
+ "Content-Language",
+ doc="""The Content-Language entity-header field describes the
+ natural language(s) of the intended audience for the enclosed
+ entity. Note that this might not be equivalent to all the
+ languages used within the entity-body.""",
+ )
+ allow = _set_property(
+ "Allow",
+ doc="""The Allow entity-header field lists the set of methods
+ supported by the resource identified by the Request-URI. The
+ purpose of this field is strictly to inform the recipient of
+ valid methods associated with the resource. An Allow header
+ field MUST be present in a 405 (Method Not Allowed)
+ response.""",
+ )
+
+ del _set_property
diff --git a/src/werkzeug/wrappers/etag.py b/src/werkzeug/wrappers/etag.py
index 7ad12f61..0733506f 100644
--- a/src/werkzeug/wrappers/etag.py
+++ b/src/werkzeug/wrappers/etag.py
@@ -31,9 +31,8 @@ class ETagRequestMixin(object):
"""A :class:`~werkzeug.datastructures.RequestCacheControl` object
for the incoming cache control headers.
"""
- cache_control = self.environ.get('HTTP_CACHE_CONTROL')
- return parse_cache_control_header(cache_control, None,
- RequestCacheControl)
+ cache_control = self.environ.get("HTTP_CACHE_CONTROL")
+ return parse_cache_control_header(cache_control, None, RequestCacheControl)
@cached_property
def if_match(self):
@@ -41,7 +40,7 @@ class ETagRequestMixin(object):
:rtype: :class:`~werkzeug.datastructures.ETags`
"""
- return parse_etags(self.environ.get('HTTP_IF_MATCH'))
+ return parse_etags(self.environ.get("HTTP_IF_MATCH"))
@cached_property
def if_none_match(self):
@@ -49,17 +48,17 @@ class ETagRequestMixin(object):
:rtype: :class:`~werkzeug.datastructures.ETags`
"""
- return parse_etags(self.environ.get('HTTP_IF_NONE_MATCH'))
+ return parse_etags(self.environ.get("HTTP_IF_NONE_MATCH"))
@cached_property
def if_modified_since(self):
"""The parsed `If-Modified-Since` header as datetime object."""
- return parse_date(self.environ.get('HTTP_IF_MODIFIED_SINCE'))
+ return parse_date(self.environ.get("HTTP_IF_MODIFIED_SINCE"))
@cached_property
def if_unmodified_since(self):
"""The parsed `If-Unmodified-Since` header as datetime object."""
- return parse_date(self.environ.get('HTTP_IF_UNMODIFIED_SINCE'))
+ return parse_date(self.environ.get("HTTP_IF_UNMODIFIED_SINCE"))
@cached_property
def if_range(self):
@@ -69,7 +68,7 @@ class ETagRequestMixin(object):
:rtype: :class:`~werkzeug.datastructures.IfRange`
"""
- return parse_if_range_header(self.environ.get('HTTP_IF_RANGE'))
+ return parse_if_range_header(self.environ.get("HTTP_IF_RANGE"))
@cached_property
def range(self):
@@ -79,7 +78,7 @@ class ETagRequestMixin(object):
:rtype: :class:`~werkzeug.datastructures.Range`
"""
- return parse_range_header(self.environ.get('HTTP_RANGE'))
+ return parse_range_header(self.environ.get("HTTP_RANGE"))
class ETagResponseMixin(object):
@@ -99,14 +98,16 @@ class ETagResponseMixin(object):
directives that MUST be obeyed by all caching mechanisms along the
request/response chain.
"""
+
def on_update(cache_control):
- if not cache_control and 'cache-control' in self.headers:
- del self.headers['cache-control']
+ if not cache_control and "cache-control" in self.headers:
+ del self.headers["cache-control"]
elif cache_control:
- self.headers['Cache-Control'] = cache_control.to_header()
- return parse_cache_control_header(self.headers.get('cache-control'),
- on_update,
- ResponseCacheControl)
+ self.headers["Cache-Control"] = cache_control.to_header()
+
+ return parse_cache_control_header(
+ self.headers.get("cache-control"), on_update, ResponseCacheControl
+ )
def _wrap_response(self, start, length):
"""Wrap existing Response in case of Range Request context."""
@@ -118,12 +119,15 @@ class ETagResponseMixin(object):
resource is considered unchanged when compared with `If-Range` header.
"""
return (
- 'HTTP_IF_RANGE' not in environ
+ "HTTP_IF_RANGE" not in environ
or not is_resource_modified(
- environ, self.headers.get('etag'), None,
- self.headers.get('last-modified'), ignore_if_range=False
+ environ,
+ self.headers.get("etag"),
+ None,
+ self.headers.get("last-modified"),
+ ignore_if_range=False,
)
- ) and 'HTTP_RANGE' in environ
+ ) and "HTTP_RANGE" in environ
def _process_range_request(self, environ, complete_length=None, accept_ranges=None):
"""Handle Range Request related headers (RFC7233). If `Accept-Ranges`
@@ -136,13 +140,14 @@ class ETagResponseMixin(object):
:raises: :class:`~werkzeug.exceptions.RequestedRangeNotSatisfiable`
if `Range` header could not be parsed or satisfied.
"""
- from werkzeug.exceptions import RequestedRangeNotSatisfiable
+ from ..exceptions import RequestedRangeNotSatisfiable
+
if accept_ranges is None:
return False
- self.headers['Accept-Ranges'] = accept_ranges
+ self.headers["Accept-Ranges"] = accept_ranges
if not self._is_range_request_processable(environ) or complete_length is None:
return False
- parsed_range = parse_range_header(environ.get('HTTP_RANGE'))
+ parsed_range = parse_range_header(environ.get("HTTP_RANGE"))
if parsed_range is None:
raise RequestedRangeNotSatisfiable(complete_length)
range_tuple = parsed_range.range_for_length(complete_length)
@@ -153,15 +158,16 @@ class ETagResponseMixin(object):
# Be sure not to send 206 response
# if requested range is the full content.
if content_length != complete_length:
- self.headers['Content-Length'] = content_length
+ self.headers["Content-Length"] = content_length
self.content_range = content_range_header
self.status_code = 206
self._wrap_response(range_tuple[0], content_length)
return True
return False
- def make_conditional(self, request_or_environ, accept_ranges=False,
- complete_length=None):
+ def make_conditional(
+ self, request_or_environ, accept_ranges=False, complete_length=None
+ ):
"""Make the response conditional to the request. This method works
best if an etag was defined for the response already. The `add_etag`
method can be used to do that. If called without etag just the date
@@ -199,43 +205,48 @@ class ETagResponseMixin(object):
if `Range` header could not be parsed or satisfied.
"""
environ = _get_environ(request_or_environ)
- if environ['REQUEST_METHOD'] in ('GET', 'HEAD'):
+ if environ["REQUEST_METHOD"] in ("GET", "HEAD"):
# if the date is not in the headers, add it now. We however
# will not override an already existing header. Unfortunately
# this header will be overriden by many WSGI servers including
# wsgiref.
- if 'date' not in self.headers:
- self.headers['Date'] = http_date()
+ if "date" not in self.headers:
+ self.headers["Date"] = http_date()
accept_ranges = _clean_accept_ranges(accept_ranges)
is206 = self._process_range_request(environ, complete_length, accept_ranges)
if not is206 and not is_resource_modified(
- environ, self.headers.get('etag'), None,
- self.headers.get('last-modified')
+ environ,
+ self.headers.get("etag"),
+ None,
+ self.headers.get("last-modified"),
):
- if parse_etags(environ.get('HTTP_IF_MATCH')):
+ if parse_etags(environ.get("HTTP_IF_MATCH")):
self.status_code = 412
else:
self.status_code = 304
- if self.automatically_set_content_length and 'content-length' not in self.headers:
+ if (
+ self.automatically_set_content_length
+ and "content-length" not in self.headers
+ ):
length = self.calculate_content_length()
if length is not None:
- self.headers['Content-Length'] = length
+ self.headers["Content-Length"] = length
return self
def add_etag(self, overwrite=False, weak=False):
"""Add an etag for the current response if there is none yet."""
- if overwrite or 'etag' not in self.headers:
+ if overwrite or "etag" not in self.headers:
self.set_etag(generate_etag(self.get_data()), weak)
def set_etag(self, etag, weak=False):
"""Set the etag, and override the old one if there was one."""
- self.headers['ETag'] = quote_etag(etag, weak)
+ self.headers["ETag"] = quote_etag(etag, weak)
def get_etag(self):
"""Return a tuple in the form ``(etag, is_weak)``. If there is no
ETag the return value is ``(None, None)``.
"""
- return unquote_etag(self.headers.get('ETag'))
+ return unquote_etag(self.headers.get("ETag"))
def freeze(self, no_etag=False):
"""Call this method if you want to make your response object ready for
@@ -246,22 +257,25 @@ class ETagResponseMixin(object):
self.add_etag()
super(ETagResponseMixin, self).freeze()
- accept_ranges = header_property('Accept-Ranges', doc='''
- The `Accept-Ranges` header. Even though the name would indicate
- that multiple values are supported, it must be one string token only.
+ accept_ranges = header_property(
+ "Accept-Ranges",
+ doc="""The `Accept-Ranges` header. Even though the name would
+ indicate that multiple values are supported, it must be one
+ string token only.
The values ``'bytes'`` and ``'none'`` are common.
- .. versionadded:: 0.7''')
+ .. versionadded:: 0.7""",
+ )
def _get_content_range(self):
def on_update(rng):
if not rng:
- del self.headers['content-range']
+ del self.headers["content-range"]
else:
- self.headers['Content-Range'] = rng.to_header()
- rv = parse_content_range_header(self.headers.get('content-range'),
- on_update)
+ self.headers["Content-Range"] = rng.to_header()
+
+ rv = parse_content_range_header(self.headers.get("content-range"), on_update)
# always provide a content range object to make the descriptor
# more user friendly. It provides an unset() method that can be
# used to remove the header quickly.
@@ -271,16 +285,20 @@ class ETagResponseMixin(object):
def _set_content_range(self, value):
if not value:
- del self.headers['content-range']
+ del self.headers["content-range"]
elif isinstance(value, string_types):
- self.headers['Content-Range'] = value
+ self.headers["Content-Range"] = value
else:
- self.headers['Content-Range'] = value.to_header()
- content_range = property(_get_content_range, _set_content_range, doc='''
- The `Content-Range` header as
- :class:`~werkzeug.datastructures.ContentRange` object. Even if the
- header is not set it wil provide such an object for easier
+ self.headers["Content-Range"] = value.to_header()
+
+ content_range = property(
+ _get_content_range,
+ _set_content_range,
+ doc="""The ``Content-Range`` header as
+ :class:`~werkzeug.datastructures.ContentRange` object. Even if
+ the header is not set it wil provide such an object for easier
manipulation.
- .. versionadded:: 0.7''')
+ .. versionadded:: 0.7""",
+ )
del _get_content_range, _set_content_range
diff --git a/src/werkzeug/wrappers/json.py b/src/werkzeug/wrappers/json.py
index 851806dd..6d5dc33d 100644
--- a/src/werkzeug/wrappers/json.py
+++ b/src/werkzeug/wrappers/json.py
@@ -75,9 +75,9 @@ class JSONMixin(object):
"""
mt = self.mimetype
return (
- mt == 'application/json'
- or mt.startswith('application/')
- and mt.endswith('+json')
+ mt == "application/json"
+ or mt.startswith("application/")
+ and mt.endswith("+json")
)
def _get_data_for_json(self, cache):
@@ -142,4 +142,4 @@ class JSONMixin(object):
for :meth:`get_json`. The default implementation raises
:exc:`~werkzeug.exceptions.BadRequest`.
"""
- raise BadRequest('Failed to decode JSON object: {0}'.format(e))
+ raise BadRequest("Failed to decode JSON object: {0}".format(e))
diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py
index 5f1b00c8..d1c71b64 100644
--- a/src/werkzeug/wrappers/request.py
+++ b/src/werkzeug/wrappers/request.py
@@ -6,9 +6,14 @@ from .etag import ETagRequestMixin
from .user_agent import UserAgentMixin
-class Request(BaseRequest, AcceptMixin, ETagRequestMixin,
- UserAgentMixin, AuthorizationMixin,
- CommonRequestDescriptorsMixin):
+class Request(
+ BaseRequest,
+ AcceptMixin,
+ ETagRequestMixin,
+ UserAgentMixin,
+ AuthorizationMixin,
+ CommonRequestDescriptorsMixin,
+):
"""Full featured request object implementing the following mixins:
- :class:`AcceptMixin` for accept header parsing
diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py
index 8b7dc7ba..cd86cacd 100644
--- a/src/werkzeug/wrappers/response.py
+++ b/src/werkzeug/wrappers/response.py
@@ -11,7 +11,7 @@ class ResponseStream(object):
iterable of the response object.
"""
- mode = 'wb+'
+ mode = "wb+"
def __init__(self, response):
self.response = response
@@ -19,10 +19,10 @@ class ResponseStream(object):
def write(self, value):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
self.response._ensure_sequence(mutable=True)
self.response.response.append(value)
- self.response.headers.pop('Content-Length', None)
+ self.response.headers.pop("Content-Length", None)
return len(value)
def writelines(self, seq):
@@ -34,11 +34,11 @@ class ResponseStream(object):
def flush(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
def isatty(self):
if self.closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return False
def tell(self):
@@ -62,9 +62,13 @@ class ResponseStreamMixin(object):
return ResponseStream(self)
-class Response(BaseResponse, ETagResponseMixin, ResponseStreamMixin,
- CommonResponseDescriptorsMixin,
- WWWAuthenticateMixin):
+class Response(
+ BaseResponse,
+ ETagResponseMixin,
+ ResponseStreamMixin,
+ CommonResponseDescriptorsMixin,
+ WWWAuthenticateMixin,
+):
"""Full featured response object implementing the following mixins:
- :class:`ETagResponseMixin` for etag and cache control handling
diff --git a/src/werkzeug/wrappers/user_agent.py b/src/werkzeug/wrappers/user_agent.py
index be106efe..72588dd9 100644
--- a/src/werkzeug/wrappers/user_agent.py
+++ b/src/werkzeug/wrappers/user_agent.py
@@ -2,13 +2,14 @@ from ..utils import cached_property
class UserAgentMixin(object):
- """Adds a `user_agent` attribute to the request object which contains the
- parsed user agent of the browser that triggered the request as a
- :class:`~werkzeug.useragents.UserAgent` object.
+ """Adds a `user_agent` attribute to the request object which
+ contains the parsed user agent of the browser that triggered the
+ request as a :class:`~werkzeug.useragents.UserAgent` object.
"""
@cached_property
def user_agent(self):
"""The current user agent."""
- from werkzeug.useragents import UserAgent
+ from ..useragents import UserAgent
+
return UserAgent(self.environ)
diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py
index 60966503..f069f2d8 100644
--- a/src/werkzeug/wsgi.py
+++ b/src/werkzeug/wsgi.py
@@ -11,14 +11,24 @@
import io
import re
import warnings
-from functools import partial, update_wrapper
+from functools import partial
+from functools import update_wrapper
from itertools import chain
-from werkzeug._compat import BytesIO, implements_iterator, \
- make_literal_wrapper, string_types, text_type, to_bytes, to_unicode, \
- try_coerce_native, wsgi_get_bytes
-from werkzeug._internal import _encode_idna
-from werkzeug.urls import uri_to_iri, url_join, url_parse, url_quote
+from ._compat import BytesIO
+from ._compat import implements_iterator
+from ._compat import make_literal_wrapper
+from ._compat import string_types
+from ._compat import text_type
+from ._compat import to_bytes
+from ._compat import to_unicode
+from ._compat import try_coerce_native
+from ._compat import wsgi_get_bytes
+from ._internal import _encode_idna
+from .urls import uri_to_iri
+from .urls import url_join
+from .urls import url_parse
+from .urls import url_quote
def responder(f):
@@ -34,8 +44,13 @@ def responder(f):
return update_wrapper(lambda *a: f(*a)(*a[-2:]), f)
-def get_current_url(environ, root_only=False, strip_querystring=False,
- host_only=False, trusted_hosts=None):
+def get_current_url(
+ environ,
+ root_only=False,
+ strip_querystring=False,
+ host_only=False,
+ trusted_hosts=None,
+):
"""A handy helper function that recreates the full URL as IRI for the
current request or parts of it. Here's an example:
@@ -70,19 +85,19 @@ def get_current_url(environ, root_only=False, strip_querystring=False,
:param trusted_hosts: a list of trusted hosts, see :func:`host_is_trusted`
for more information.
"""
- tmp = [environ['wsgi.url_scheme'], '://', get_host(environ, trusted_hosts)]
+ tmp = [environ["wsgi.url_scheme"], "://", get_host(environ, trusted_hosts)]
cat = tmp.append
if host_only:
- return uri_to_iri(''.join(tmp) + '/')
- cat(url_quote(wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))).rstrip('/'))
- cat('/')
+ return uri_to_iri("".join(tmp) + "/")
+ cat(url_quote(wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))).rstrip("/"))
+ cat("/")
if not root_only:
- cat(url_quote(wsgi_get_bytes(environ.get('PATH_INFO', '')).lstrip(b'/')))
+ cat(url_quote(wsgi_get_bytes(environ.get("PATH_INFO", "")).lstrip(b"/")))
if not strip_querystring:
qs = get_query_string(environ)
if qs:
- cat('?' + qs)
- return uri_to_iri(''.join(tmp))
+ cat("?" + qs)
+ return uri_to_iri("".join(tmp))
def host_is_trusted(hostname, trusted_list):
@@ -103,8 +118,8 @@ def host_is_trusted(hostname, trusted_list):
trusted_list = [trusted_list]
def _normalize(hostname):
- if ':' in hostname:
- hostname = hostname.rsplit(':', 1)[0]
+ if ":" in hostname:
+ hostname = hostname.rsplit(":", 1)[0]
return _encode_idna(hostname)
try:
@@ -112,7 +127,7 @@ def host_is_trusted(hostname, trusted_list):
except UnicodeError:
return False
for ref in trusted_list:
- if ref.startswith('.'):
+ if ref.startswith("."):
ref = ref[1:]
suffix_match = True
else:
@@ -123,7 +138,7 @@ def host_is_trusted(hostname, trusted_list):
return False
if ref == hostname:
return True
- if suffix_match and hostname.endswith(b'.' + ref):
+ if suffix_match and hostname.endswith(b"." + ref):
return True
return False
@@ -144,20 +159,23 @@ def get_host(environ, trusted_hosts=None):
:raise ~werkzeug.exceptions.SecurityError: If the host is not
trusted.
"""
- if 'HTTP_HOST' in environ:
- rv = environ['HTTP_HOST']
- if environ['wsgi.url_scheme'] == 'http' and rv.endswith(':80'):
+ if "HTTP_HOST" in environ:
+ rv = environ["HTTP_HOST"]
+ if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"):
rv = rv[:-3]
- elif environ['wsgi.url_scheme'] == 'https' and rv.endswith(':443'):
+ elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"):
rv = rv[:-4]
else:
- rv = environ['SERVER_NAME']
- if (environ['wsgi.url_scheme'], environ['SERVER_PORT']) not \
- in (('https', '443'), ('http', '80')):
- rv += ':' + environ['SERVER_PORT']
+ rv = environ["SERVER_NAME"]
+ if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in (
+ ("https", "443"),
+ ("http", "80"),
+ ):
+ rv += ":" + environ["SERVER_PORT"]
if trusted_hosts is not None:
if not host_is_trusted(rv, trusted_hosts):
- from werkzeug.exceptions import SecurityError
+ from .exceptions import SecurityError
+
raise SecurityError('Host "%s" is not trusted' % rv)
return rv
@@ -171,10 +189,10 @@ def get_content_length(environ):
:param environ: the WSGI environ to fetch the content length from.
"""
- if environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked':
+ if environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked":
return None
- content_length = environ.get('CONTENT_LENGTH')
+ content_length = environ.get("CONTENT_LENGTH")
if content_length is not None:
try:
return max(0, int(content_length))
@@ -199,13 +217,13 @@ def get_input_stream(environ, safe_fallback=True):
content length is not set. Disabling this allows infinite streams,
which can be a denial-of-service risk.
"""
- stream = environ['wsgi.input']
+ stream = environ["wsgi.input"]
content_length = get_content_length(environ)
# A wsgi extension that tells us if the input is terminated. In
# that case we return the stream unchanged as we know we can safely
# read it until the end.
- if environ.get('wsgi.input_terminated'):
+ if environ.get("wsgi.input_terminated"):
return stream
# If the request doesn't specify a content length, returning the stream is
@@ -228,14 +246,14 @@ def get_query_string(environ):
:param environ: the WSGI environment object to get the query string from.
"""
- qs = wsgi_get_bytes(environ.get('QUERY_STRING', ''))
+ qs = wsgi_get_bytes(environ.get("QUERY_STRING", ""))
# QUERY_STRING really should be ascii safe but some browsers
# will send us some unicode stuff (I am looking at you IE).
# In that case we want to urllib quote it badly.
- return try_coerce_native(url_quote(qs, safe=':&%=+$!*\'(),'))
+ return try_coerce_native(url_quote(qs, safe=":&%=+$!*'(),"))
-def get_path_info(environ, charset='utf-8', errors='replace'):
+def get_path_info(environ, charset="utf-8", errors="replace"):
"""Returns the `PATH_INFO` from the WSGI environment and properly
decodes it. This also takes care about the WSGI decoding dance
on Python 3 environments. if the `charset` is set to `None` a
@@ -248,11 +266,11 @@ def get_path_info(environ, charset='utf-8', errors='replace'):
decoding should be performed.
:param errors: the decoding error handling.
"""
- path = wsgi_get_bytes(environ.get('PATH_INFO', ''))
+ path = wsgi_get_bytes(environ.get("PATH_INFO", ""))
return to_unicode(path, charset, errors, allow_none_charset=True)
-def get_script_name(environ, charset='utf-8', errors='replace'):
+def get_script_name(environ, charset="utf-8", errors="replace"):
"""Returns the `SCRIPT_NAME` from the WSGI environment and properly
decodes it. This also takes care about the WSGI decoding dance
on Python 3 environments. if the `charset` is set to `None` a
@@ -265,11 +283,11 @@ def get_script_name(environ, charset='utf-8', errors='replace'):
decoding should be performed.
:param errors: the decoding error handling.
"""
- path = wsgi_get_bytes(environ.get('SCRIPT_NAME', ''))
+ path = wsgi_get_bytes(environ.get("SCRIPT_NAME", ""))
return to_unicode(path, charset, errors, allow_none_charset=True)
-def pop_path_info(environ, charset='utf-8', errors='replace'):
+def pop_path_info(environ, charset="utf-8", errors="replace"):
"""Removes and returns the next segment of `PATH_INFO`, pushing it onto
`SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`.
@@ -296,32 +314,32 @@ def pop_path_info(environ, charset='utf-8', errors='replace'):
:param environ: the WSGI environment that is modified.
"""
- path = environ.get('PATH_INFO')
+ path = environ.get("PATH_INFO")
if not path:
return None
- script_name = environ.get('SCRIPT_NAME', '')
+ script_name = environ.get("SCRIPT_NAME", "")
# shift multiple leading slashes over
old_path = path
- path = path.lstrip('/')
+ path = path.lstrip("/")
if path != old_path:
- script_name += '/' * (len(old_path) - len(path))
+ script_name += "/" * (len(old_path) - len(path))
- if '/' not in path:
- environ['PATH_INFO'] = ''
- environ['SCRIPT_NAME'] = script_name + path
+ if "/" not in path:
+ environ["PATH_INFO"] = ""
+ environ["SCRIPT_NAME"] = script_name + path
rv = wsgi_get_bytes(path)
else:
- segment, path = path.split('/', 1)
- environ['PATH_INFO'] = '/' + path
- environ['SCRIPT_NAME'] = script_name + segment
+ segment, path = path.split("/", 1)
+ environ["PATH_INFO"] = "/" + path
+ environ["SCRIPT_NAME"] = script_name + segment
rv = wsgi_get_bytes(segment)
return to_unicode(rv, charset, errors, allow_none_charset=True)
-def peek_path_info(environ, charset='utf-8', errors='replace'):
+def peek_path_info(environ, charset="utf-8", errors="replace"):
"""Returns the next segment on the `PATH_INFO` or `None` if there
is none. Works like :func:`pop_path_info` without modifying the
environment:
@@ -342,18 +360,19 @@ def peek_path_info(environ, charset='utf-8', errors='replace'):
:param environ: the WSGI environment that is checked.
"""
- segments = environ.get('PATH_INFO', '').lstrip('/').split('/', 1)
+ segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1)
if segments:
- return to_unicode(wsgi_get_bytes(segments[0]),
- charset, errors, allow_none_charset=True)
+ return to_unicode(
+ wsgi_get_bytes(segments[0]), charset, errors, allow_none_charset=True
+ )
def extract_path_info(
environ_or_baseurl,
path_or_url,
- charset='utf-8',
- errors='werkzeug.url_quote',
- collapse_http_schemes=True
+ charset="utf-8",
+ errors="werkzeug.url_quote",
+ collapse_http_schemes=True,
):
"""Extracts the path info from the given URL (or WSGI environment) and
path. The path info returned is a unicode string, not a bytestring
@@ -395,29 +414,29 @@ def extract_path_info(
.. versionadded:: 0.6
"""
+
def _normalize_netloc(scheme, netloc):
- parts = netloc.split(u'@', 1)[-1].split(u':', 1)
+ parts = netloc.split(u"@", 1)[-1].split(u":", 1)
if len(parts) == 2:
netloc, port = parts
- if (scheme == u'http' and port == u'80') or \
- (scheme == u'https' and port == u'443'):
+ if (scheme == u"http" and port == u"80") or (
+ scheme == u"https" and port == u"443"
+ ):
port = None
else:
netloc = parts[0]
port = None
if port is not None:
- netloc += u':' + port
+ netloc += u":" + port
return netloc
# make sure whatever we are working on is a IRI and parse it
path = uri_to_iri(path_or_url, charset, errors)
if isinstance(environ_or_baseurl, dict):
- environ_or_baseurl = get_current_url(environ_or_baseurl,
- root_only=True)
+ environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True)
base_iri = uri_to_iri(environ_or_baseurl, charset, errors)
base_scheme, base_netloc, base_path = url_parse(base_iri)[:3]
- cur_scheme, cur_netloc, cur_path, = \
- url_parse(url_join(base_iri, path))[:3]
+ cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3]
# normalize the network location
base_netloc = _normalize_netloc(base_scheme, base_netloc)
@@ -426,11 +445,10 @@ def extract_path_info(
# is that IRI even on a known HTTP scheme?
if collapse_http_schemes:
for scheme in base_scheme, cur_scheme:
- if scheme not in (u'http', u'https'):
+ if scheme not in (u"http", u"https"):
return None
else:
- if not (base_scheme in (u'http', u'https')
- and base_scheme == cur_scheme):
+ if not (base_scheme in (u"http", u"https") and base_scheme == cur_scheme):
return None
# are the netlocs compatible?
@@ -438,16 +456,15 @@ def extract_path_info(
return None
# are we below the application path?
- base_path = base_path.rstrip(u'/')
+ base_path = base_path.rstrip(u"/")
if not cur_path.startswith(base_path):
return None
- return u'/' + cur_path[len(base_path):].lstrip(u'/')
+ return u"/" + cur_path[len(base_path) :].lstrip(u"/")
@implements_iterator
class ClosingIterator(object):
-
"""The WSGI specification requires that all middlewares and gateways
respect the `close` callback of the iterable returned by the application.
Because it is useful to add another close action to a returned iterable
@@ -478,7 +495,7 @@ class ClosingIterator(object):
callbacks = [callbacks]
else:
callbacks = list(callbacks)
- iterable_close = getattr(iterable, 'close', None)
+ iterable_close = getattr(iterable, "close", None)
if iterable_close:
callbacks.insert(0, iterable_close)
self._callbacks = callbacks
@@ -510,12 +527,11 @@ def wrap_file(environ, file, buffer_size=8192):
:param file: a :class:`file`-like object with a :meth:`~file.read` method.
:param buffer_size: number of bytes for one iteration.
"""
- return environ.get('wsgi.file_wrapper', FileWrapper)(file, buffer_size)
+ return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size)
@implements_iterator
class FileWrapper(object):
-
"""This class can be used to convert a :class:`file`-like object into
an iterable. It yields `buffer_size` blocks until the file is fully
read.
@@ -538,22 +554,22 @@ class FileWrapper(object):
self.buffer_size = buffer_size
def close(self):
- if hasattr(self.file, 'close'):
+ if hasattr(self.file, "close"):
self.file.close()
def seekable(self):
- if hasattr(self.file, 'seekable'):
+ if hasattr(self.file, "seekable"):
return self.file.seekable()
- if hasattr(self.file, 'seek'):
+ if hasattr(self.file, "seek"):
return True
return False
def seek(self, *args):
- if hasattr(self.file, 'seek'):
+ if hasattr(self.file, "seek"):
self.file.seek(*args)
def tell(self):
- if hasattr(self.file, 'tell'):
+ if hasattr(self.file, "tell"):
return self.file.tell()
return None
@@ -593,7 +609,7 @@ class _RangeWrapper(object):
if byte_range is not None:
self.end_byte = self.start_byte + self.byte_range
self.read_length = 0
- self.seekable = hasattr(iterable, 'seekable') and iterable.seekable()
+ self.seekable = hasattr(iterable, "seekable") and iterable.seekable()
self.end_reached = False
def __iter__(self):
@@ -618,7 +634,7 @@ class _RangeWrapper(object):
while self.read_length <= self.start_byte:
chunk = self._next_chunk()
if chunk is not None:
- chunk = chunk[self.start_byte - self.read_length:]
+ chunk = chunk[self.start_byte - self.read_length :]
contextual_read_length = self.start_byte
return chunk, contextual_read_length
@@ -633,7 +649,7 @@ class _RangeWrapper(object):
chunk = self._next_chunk()
if self.end_byte is not None and self.read_length >= self.end_byte:
self.end_reached = True
- return chunk[:self.end_byte - contextual_read_length]
+ return chunk[: self.end_byte - contextual_read_length]
return chunk
def __next__(self):
@@ -644,16 +660,17 @@ class _RangeWrapper(object):
raise StopIteration()
def close(self):
- if hasattr(self.iterable, 'close'):
+ if hasattr(self.iterable, "close"):
self.iterable.close()
def _make_chunk_iter(stream, limit, buffer_size):
"""Helper for the line and chunk iter functions."""
if isinstance(stream, (bytes, bytearray, text_type)):
- raise TypeError('Passed a string or byte object instead of '
- 'true iterator or stream.')
- if not hasattr(stream, 'read'):
+ raise TypeError(
+ "Passed a string or byte object instead of true iterator or stream."
+ )
+ if not hasattr(stream, "read"):
for item in stream:
if item:
yield item
@@ -668,8 +685,7 @@ def _make_chunk_iter(stream, limit, buffer_size):
yield item
-def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
- cap_at_buffer=False):
+def make_line_iter(stream, limit=None, buffer_size=10 * 1024, cap_at_buffer=False):
"""Safely iterates line-based over an input stream. If the input stream
is not a :class:`LimitedStream` the `limit` parameter is mandatory.
@@ -703,15 +719,15 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
"""
_iter = _make_chunk_iter(stream, limit, buffer_size)
- first_item = next(_iter, '')
+ first_item = next(_iter, "")
if not first_item:
return
s = make_literal_wrapper(first_item)
- empty = s('')
- cr = s('\r')
- lf = s('\n')
- crlf = s('\r\n')
+ empty = s("")
+ cr = s("\r")
+ lf = s("\n")
+ crlf = s("\r\n")
_iter = chain((first_item,), _iter)
@@ -719,7 +735,7 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
_join = empty.join
buffer = []
while 1:
- new_data = next(_iter, '')
+ new_data = next(_iter, "")
if not new_data:
break
new_buf = []
@@ -754,8 +770,9 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024,
yield previous
-def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
- cap_at_buffer=False):
+def make_chunk_iter(
+ stream, separator, limit=None, buffer_size=10 * 1024, cap_at_buffer=False
+):
"""Works like :func:`make_line_iter` but accepts a separator
which divides chunks. If you want newline based processing
you should use :func:`make_line_iter` instead as it
@@ -782,23 +799,23 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
"""
_iter = _make_chunk_iter(stream, limit, buffer_size)
- first_item = next(_iter, '')
+ first_item = next(_iter, "")
if not first_item:
return
_iter = chain((first_item,), _iter)
if isinstance(first_item, text_type):
separator = to_unicode(separator)
- _split = re.compile(r'(%s)' % re.escape(separator)).split
- _join = u''.join
+ _split = re.compile(r"(%s)" % re.escape(separator)).split
+ _join = u"".join
else:
separator = to_bytes(separator)
- _split = re.compile(b'(' + re.escape(separator) + b')').split
- _join = b''.join
+ _split = re.compile(b"(" + re.escape(separator) + b")").split
+ _join = b"".join
buffer = []
while 1:
- new_data = next(_iter, '')
+ new_data = next(_iter, "")
if not new_data:
break
chunks = _split(new_data)
@@ -828,7 +845,6 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024,
@implements_iterator
class LimitedStream(io.IOBase):
-
"""Wraps a stream so that it doesn't read more than n bytes. If the
stream is exhausted and the caller tries to get more bytes from it
:func:`on_exhausted` is called which by default returns an empty
@@ -891,7 +907,8 @@ class LimitedStream(io.IOBase):
the client went away. By default a
:exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised.
"""
- from werkzeug.exceptions import ClientDisconnected
+ from .exceptions import ClientDisconnected
+
raise ClientDisconnected()
def exhaust(self, chunk_size=1024 * 64):
@@ -984,9 +1001,10 @@ class LimitedStream(io.IOBase):
return True
-from werkzeug.middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware
-from werkzeug.middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware
-from werkzeug.middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware
+# DEPRECATED
+from .middleware.dispatcher import DispatcherMiddleware as _DispatcherMiddleware
+from .middleware.http_proxy import ProxyMiddleware as _ProxyMiddleware
+from .middleware.shared_data import SharedDataMiddleware as _SharedDataMiddleware
class ProxyMiddleware(_ProxyMiddleware):
diff --git a/tests/conftest.py b/tests/conftest.py
index 0a76d337..1ce4fd53 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,17 +8,15 @@
"""
from __future__ import print_function
-from itertools import count
-
import logging
import os
import platform
import signal
-import sys
-
import subprocess
+import sys
import textwrap
import time
+from itertools import count
import pytest
@@ -28,11 +26,12 @@ from werkzeug.urls import url_quote
from werkzeug.utils import cached_property
try:
- __import__('pytest_xprocess')
+ __import__("pytest_xprocess")
except ImportError:
- @pytest.fixture(scope='session')
+
+ @pytest.fixture(scope="session")
def xprocess():
- pytest.skip('pytest-xprocess not installed.')
+ pytest.skip("pytest-xprocess not installed.")
port_generator = count(13220)
@@ -40,7 +39,7 @@ port_generator = count(13220)
def _patch_reloader_loop():
def f(x):
- print('reloader loop finished')
+ print("reloader loop finished")
# Need to flush for some reason even though xprocess opens the
# subprocess' stdout in unbuffered mode.
# flush=True makes the test fail on py2, so flush manually
@@ -48,6 +47,7 @@ def _patch_reloader_loop():
return time.sleep(x)
import werkzeug._reloader
+
werkzeug._reloader.ReloaderLoop._sleep = staticmethod(f)
@@ -59,11 +59,12 @@ pid_logger.addHandler(pid_handler)
def _get_pid_middleware(f):
def inner(environ, start_response):
- if environ['PATH_INFO'] == '/_getpid':
- start_response('200 OK', [('Content-Type', 'text/plain')])
+ if environ["PATH_INFO"] == "/_getpid":
+ start_response("200 OK", [("Content-Type", "text/plain")])
pid_logger.info("pid=%s", os.getpid())
return [to_bytes(str(os.getpid()))]
return f(environ, start_response)
+
return inner
@@ -71,6 +72,7 @@ def _dev_server():
_patch_reloader_loop()
sys.path.insert(0, sys.argv[1])
import testsuite_app
+
app = _get_pid_middleware(testsuite_app.app)
serving.run_simple(application=app, **testsuite_app.kwargs)
@@ -90,10 +92,10 @@ class _ServerInfo(object):
@cached_property
def logfile(self):
- return self.xprocess.getinfo('dev_server').logpath.open()
+ return self.xprocess.getinfo("dev_server").logpath.open()
def request_pid(self):
- if self.url.startswith('http+unix://'):
+ if self.url.startswith("http+unix://"):
from requests_unixsocket import get as rget
else:
from requests import get as rget
@@ -101,7 +103,7 @@ class _ServerInfo(object):
for i in range(10):
time.sleep(0.1 * i)
try:
- response = rget(self.url + '/_getpid', verify=False)
+ response = rget(self.url + "/_getpid", verify=False)
self.last_pid = int(response.text)
return self.last_pid
except Exception as e: # urllib also raises socketerrors
@@ -114,51 +116,58 @@ class _ServerInfo(object):
time.sleep(0.1 * i)
new_pid = self.request_pid()
if not new_pid:
- raise RuntimeError('Server is down.')
+ raise RuntimeError("Server is down.")
if new_pid != old_pid:
return
- raise RuntimeError('Server did not reload.')
+ raise RuntimeError("Server did not reload.")
def wait_for_reloader_loop(self):
for i in range(20):
time.sleep(0.1 * i)
line = self.logfile.readline()
- if 'reloader loop finished' in line:
+ if "reloader loop finished" in line:
return
@pytest.fixture
def dev_server(tmpdir, xprocess, request, monkeypatch):
- '''Run werkzeug.serving.run_simple in its own process.
+ """Run werkzeug.serving.run_simple in its own process.
:param application: String for the module that will be created. The module
must have a global ``app`` object, a ``kwargs`` dict is also available
whose values will be passed to ``run_simple``.
- '''
+ """
+
def run_dev_server(application):
- app_pkg = tmpdir.mkdir('testsuite_app')
- appfile = app_pkg.join('__init__.py')
+ app_pkg = tmpdir.mkdir("testsuite_app")
+ appfile = app_pkg.join("__init__.py")
port = next(port_generator)
- appfile.write('\n\n'.join((
- "kwargs = {{'hostname': 'localhost', 'port': {port:d}}}".format(
- port=port),
- textwrap.dedent(application)
- )))
-
- monkeypatch.delitem(sys.modules, 'testsuite_app', raising=False)
+ appfile.write(
+ "\n\n".join(
+ (
+ "kwargs = {{'hostname': 'localhost', 'port': {port:d}}}".format(
+ port=port
+ ),
+ textwrap.dedent(application),
+ )
+ )
+ )
+
+ monkeypatch.delitem(sys.modules, "testsuite_app", raising=False)
monkeypatch.syspath_prepend(str(tmpdir))
import testsuite_app
- hostname = testsuite_app.kwargs['hostname']
- port = testsuite_app.kwargs['port']
- addr = '{}:{}'.format(hostname, port)
-
- if hostname.startswith('unix://'):
- addr = hostname.split('unix://', 1)[1]
- requests_url = 'http+unix://' + url_quote(addr, safe='')
- elif testsuite_app.kwargs.get('ssl_context', None):
- requests_url = 'https://localhost:{0}'.format(port)
+
+ hostname = testsuite_app.kwargs["hostname"]
+ port = testsuite_app.kwargs["port"]
+ addr = "{}:{}".format(hostname, port)
+
+ if hostname.startswith("unix://"):
+ addr = hostname.split("unix://", 1)[1]
+ requests_url = "http+unix://" + url_quote(addr, safe="")
+ elif testsuite_app.kwargs.get("ssl_context", None):
+ requests_url = "https://localhost:{0}".format(port)
else:
- requests_url = 'http://localhost:{0}'.format(port)
+ requests_url = "http://localhost:{0}".format(port)
info = _ServerInfo(xprocess, addr, requests_url, port)
@@ -171,7 +180,7 @@ def dev_server(tmpdir, xprocess, request, monkeypatch):
def pattern(self):
return "pid=%s" % info.request_pid()
- xprocess.ensure('dev_server', Starter, restart=True)
+ xprocess.ensure("dev_server", Starter, restart=True)
@request.addfinalizer
def teardown():
@@ -191,5 +200,5 @@ def dev_server(tmpdir, xprocess, request, monkeypatch):
return run_dev_server
-if __name__ == '__main__':
+if __name__ == "__main__":
_dev_server()
diff --git a/tests/contrib/cache/conftest.py b/tests/contrib/cache/conftest.py
index cb651098..655f0fa6 100644
--- a/tests/contrib/cache/conftest.py
+++ b/tests/contrib/cache/conftest.py
@@ -4,10 +4,7 @@ import pytest
# build the path to the uwsgi marker file
# when running in tox, this will be relative to the tox env
-filename = os.path.join(
- os.environ.get('TOX_ENVTMPDIR', ''),
- 'test_uwsgi_failed'
-)
+filename = os.path.join(os.environ.get("TOX_ENVTMPDIR", ""), "test_uwsgi_failed")
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@@ -20,9 +17,9 @@ def pytest_runtest_makereport(item, call):
outcome = yield
report = outcome.get_result()
- if item.cls.__name__ != 'TestUWSGICache':
+ if item.cls.__name__ != "TestUWSGICache":
return
if report.failed:
- with open(filename, 'a') as f:
- f.write(item.name + '\n')
+ with open(filename, "a") as f:
+ f.write(item.name + "\n")
diff --git a/tests/contrib/cache/test_cache.py b/tests/contrib/cache/test_cache.py
index 0fcc5e46..c13227e3 100644
--- a/tests/contrib/cache/test_cache.py
+++ b/tests/contrib/cache/test_cache.py
@@ -12,8 +12,6 @@ import errno
import pytest
-pytestmark = pytest.mark.skip('werkzeug.contrib.cache moved to cachelib')
-
from werkzeug._compat import text_type
from werkzeug.contrib import cache
@@ -33,6 +31,8 @@ except ImportError:
except ImportError:
memcache = None
+pytestmark = pytest.mark.skip("werkzeug.contrib.cache moved to cachelib")
+
class CacheTestsBase(object):
_can_use_fast_sleep = True
@@ -41,13 +41,15 @@ class CacheTestsBase(object):
@pytest.fixture
def fast_sleep(self, monkeypatch):
if self._can_use_fast_sleep:
+
def sleep(delta):
orig_time = cache.time
- monkeypatch.setattr(cache, 'time', lambda: orig_time() + delta)
+ monkeypatch.setattr(cache, "time", lambda: orig_time() + delta)
return sleep
else:
import time
+
return time.sleep
@pytest.fixture
@@ -63,13 +65,13 @@ class CacheTestsBase(object):
class GenericCacheTests(CacheTestsBase):
def test_generic_get_dict(self, c):
- assert c.set('a', 'a')
- assert c.set('b', 'b')
- d = c.get_dict('a', 'b')
- assert 'a' in d
- assert 'a' == d['a']
- assert 'b' in d
- assert 'b' == d['b']
+ assert c.set("a", "a")
+ assert c.set("b", "b")
+ d = c.get_dict("a", "b")
+ assert "a" in d
+ assert "a" == d["a"]
+ assert "b" in d
+ assert "b" == d["b"]
def test_generic_set_get(self, c):
for i in range(3):
@@ -80,71 +82,71 @@ class GenericCacheTests(CacheTestsBase):
assert result == i * i, result
def test_generic_get_set(self, c):
- assert c.set('foo', ['bar'])
- assert c.get('foo') == ['bar']
+ assert c.set("foo", ["bar"])
+ assert c.get("foo") == ["bar"]
def test_generic_get_many(self, c):
- assert c.set('foo', ['bar'])
- assert c.set('spam', 'eggs')
- assert c.get_many('foo', 'spam') == [['bar'], 'eggs']
+ assert c.set("foo", ["bar"])
+ assert c.set("spam", "eggs")
+ assert c.get_many("foo", "spam") == [["bar"], "eggs"]
def test_generic_set_many(self, c):
- assert c.set_many({'foo': 'bar', 'spam': ['eggs']})
- assert c.get('foo') == 'bar'
- assert c.get('spam') == ['eggs']
+ assert c.set_many({"foo": "bar", "spam": ["eggs"]})
+ assert c.get("foo") == "bar"
+ assert c.get("spam") == ["eggs"]
def test_generic_add(self, c):
# sanity check that add() works like set()
- assert c.add('foo', 'bar')
- assert c.get('foo') == 'bar'
- assert not c.add('foo', 'qux')
- assert c.get('foo') == 'bar'
+ assert c.add("foo", "bar")
+ assert c.get("foo") == "bar"
+ assert not c.add("foo", "qux")
+ assert c.get("foo") == "bar"
def test_generic_delete(self, c):
- assert c.add('foo', 'bar')
- assert c.get('foo') == 'bar'
- assert c.delete('foo')
- assert c.get('foo') is None
+ assert c.add("foo", "bar")
+ assert c.get("foo") == "bar"
+ assert c.delete("foo")
+ assert c.get("foo") is None
def test_generic_delete_many(self, c):
- assert c.add('foo', 'bar')
- assert c.add('spam', 'eggs')
- assert c.delete_many('foo', 'spam')
- assert c.get('foo') is None
- assert c.get('spam') is None
+ assert c.add("foo", "bar")
+ assert c.add("spam", "eggs")
+ assert c.delete_many("foo", "spam")
+ assert c.get("foo") is None
+ assert c.get("spam") is None
def test_generic_inc_dec(self, c):
- assert c.set('foo', 1)
- assert c.inc('foo') == c.get('foo') == 2
- assert c.dec('foo') == c.get('foo') == 1
- assert c.delete('foo')
+ assert c.set("foo", 1)
+ assert c.inc("foo") == c.get("foo") == 2
+ assert c.dec("foo") == c.get("foo") == 1
+ assert c.delete("foo")
def test_generic_true_false(self, c):
- assert c.set('foo', True)
- assert c.get('foo') in (True, 1)
- assert c.set('bar', False)
- assert c.get('bar') in (False, 0)
+ assert c.set("foo", True)
+ assert c.get("foo") in (True, 1)
+ assert c.set("bar", False)
+ assert c.get("bar") in (False, 0)
def test_generic_timeout(self, c, fast_sleep):
- c.set('foo', 'bar', 0)
- assert c.get('foo') == 'bar'
- c.set('baz', 'qux', 1)
- assert c.get('baz') == 'qux'
+ c.set("foo", "bar", 0)
+ assert c.get("foo") == "bar"
+ c.set("baz", "qux", 1)
+ assert c.get("baz") == "qux"
fast_sleep(3)
# timeout of zero means no timeout
- assert c.get('foo') == 'bar'
+ assert c.get("foo") == "bar"
if self._guaranteed_deletes:
- assert c.get('baz') is None
+ assert c.get("baz") is None
def test_generic_has(self, c):
- assert c.has('foo') in (False, 0)
- assert c.has('spam') in (False, 0)
- assert c.set('foo', 'bar')
- assert c.has('foo') in (True, 1)
- assert c.has('spam') in (False, 0)
- c.delete('foo')
- assert c.has('foo') in (False, 0)
- assert c.has('spam') in (False, 0)
+ assert c.has("foo") in (False, 0)
+ assert c.has("spam") in (False, 0)
+ assert c.set("foo", "bar")
+ assert c.has("foo") in (True, 1)
+ assert c.has("spam") in (False, 0)
+ c.delete("foo")
+ assert c.has("foo") in (False, 0)
+ assert c.has("spam") in (False, 0)
class TestSimpleCache(GenericCacheTests):
@@ -154,10 +156,10 @@ class TestSimpleCache(GenericCacheTests):
def test_purge(self):
c = cache.SimpleCache(threshold=2)
- c.set('a', 'a')
- c.set('b', 'b')
- c.set('c', 'c')
- c.set('d', 'd')
+ c.set("a", "a")
+ c.set("b", "b")
+ c.set("c", "c")
+ c.set("d", "d")
# Cache purges old items *before* it sets new ones.
assert len(c._cache) == 3
@@ -178,7 +180,7 @@ class TestFileSystemCache(GenericCacheTests):
assert nof_cache_files <= THRESHOLD
def test_filesystemcache_clear(self, c):
- assert c.set('foo', 'bar')
+ assert c.set("foo", "bar")
nof_cache_files = c.get(c._fs_count_file)
assert nof_cache_files == 1
assert c.clear()
@@ -202,13 +204,13 @@ class TestFileSystemCache(GenericCacheTests):
assert nof_cache_files is None
def test_count_file_accuracy(self, c):
- assert c.set('foo', 'bar')
- assert c.set('moo', 'car')
- c.add('moo', 'tar')
+ assert c.set("foo", "bar")
+ assert c.set("moo", "car")
+ c.add("moo", "tar")
assert c.get(c._fs_count_file) == 2
- assert c.add('too', 'far')
+ assert c.add("too", "far")
assert c.get(c._fs_count_file) == 3
- assert c.delete('moo')
+ assert c.delete("moo")
assert c.get(c._fs_count_file) == 2
assert c.clear()
assert c.get(c._fs_count_file) == 0
@@ -220,124 +222,121 @@ class TestFileSystemCache(GenericCacheTests):
class TestRedisCache(GenericCacheTests):
_can_use_fast_sleep = False
- @pytest.fixture(scope='class', autouse=True)
+ @pytest.fixture(scope="class", autouse=True)
def requirements(self, xprocess):
if redis is None:
pytest.skip('Python package "redis" is not installed.')
def prepare(cwd):
- return '[Rr]eady to accept connections', ['redis-server']
+ return "[Rr]eady to accept connections", ["redis-server"]
try:
- xprocess.ensure('redis_server', prepare)
+ xprocess.ensure("redis_server", prepare)
except IOError as e:
# xprocess raises FileNotFoundError
if e.errno == errno.ENOENT:
- pytest.skip('Redis is not installed.')
+ pytest.skip("Redis is not installed.")
else:
raise
yield
- xprocess.getinfo('redis_server').terminate()
+ xprocess.getinfo("redis_server").terminate()
@pytest.fixture(params=(None, False, True))
def make_cache(self, request):
if request.param is None:
- host = 'localhost'
+ host = "localhost"
elif request.param:
host = redis.StrictRedis()
else:
host = redis.Redis()
- c = cache.RedisCache(
- host=host,
- key_prefix='werkzeug-test-case:',
- )
+ c = cache.RedisCache(host=host, key_prefix="werkzeug-test-case:")
yield lambda: c
c.clear()
def test_compat(self, c):
- assert c._client.set(c.key_prefix + 'foo', 'Awesome')
- assert c.get('foo') == b'Awesome'
- assert c._client.set(c.key_prefix + 'foo', '42')
- assert c.get('foo') == 42
+ assert c._client.set(c.key_prefix + "foo", "Awesome")
+ assert c.get("foo") == b"Awesome"
+ assert c._client.set(c.key_prefix + "foo", "42")
+ assert c.get("foo") == 42
def test_empty_host(self):
with pytest.raises(ValueError) as exc_info:
cache.RedisCache(host=None)
- assert text_type(exc_info.value) == 'RedisCache host parameter may not be None'
+ assert text_type(exc_info.value) == "RedisCache host parameter may not be None"
class TestMemcachedCache(GenericCacheTests):
_can_use_fast_sleep = False
_guaranteed_deletes = False
- @pytest.fixture(scope='class', autouse=True)
+ @pytest.fixture(scope="class", autouse=True)
def requirements(self, xprocess):
if memcache is None:
pytest.skip(
- 'Python package for memcache is not installed. Need one of '
+ "Python package for memcache is not installed. Need one of "
'"pylibmc", "google.appengine", or "memcache".'
)
def prepare(cwd):
- return '', ['memcached']
+ return "", ["memcached"]
try:
- xprocess.ensure('memcached', prepare)
+ xprocess.ensure("memcached", prepare)
except IOError as e:
# xprocess raises FileNotFoundError
if e.errno == errno.ENOENT:
- pytest.skip('Memcached is not installed.')
+ pytest.skip("Memcached is not installed.")
else:
raise
yield
- xprocess.getinfo('memcached').terminate()
+ xprocess.getinfo("memcached").terminate()
@pytest.fixture
def make_cache(self):
- c = cache.MemcachedCache(key_prefix='werkzeug-test-case:')
+ c = cache.MemcachedCache(key_prefix="werkzeug-test-case:")
yield lambda: c
c.clear()
def test_compat(self, c):
- assert c._client.set(c.key_prefix + 'foo', 'bar')
- assert c.get('foo') == 'bar'
+ assert c._client.set(c.key_prefix + "foo", "bar")
+ assert c.get("foo") == "bar"
def test_huge_timeouts(self, c):
# Timeouts greater than epoch are interpreted as POSIX timestamps
# (i.e. not relative to now, but relative to epoch)
epoch = 2592000
- c.set('foo', 'bar', epoch + 100)
- assert c.get('foo') == 'bar'
+ c.set("foo", "bar", epoch + 100)
+ assert c.get("foo") == "bar"
class TestUWSGICache(GenericCacheTests):
_can_use_fast_sleep = False
_guaranteed_deletes = False
- @pytest.fixture(scope='class', autouse=True)
+ @pytest.fixture(scope="class", autouse=True)
def requirements(self):
try:
import uwsgi # NOQA
except ImportError:
pytest.skip(
'Python "uwsgi" package is only avaialable when running '
- 'inside uWSGI.'
+ "inside uWSGI."
)
@pytest.fixture
def make_cache(self):
- c = cache.UWSGICache(cache='werkzeugtest')
+ c = cache.UWSGICache(cache="werkzeugtest")
yield lambda: c
c.clear()
class TestNullCache(CacheTestsBase):
- @pytest.fixture(scope='class', autouse=True)
+ @pytest.fixture(scope="class", autouse=True)
def make_cache(self):
return cache.NullCache
def test_has(self, c):
- assert not c.has('foo')
+ assert not c.has("foo")
diff --git a/tests/contrib/test_atom.py b/tests/contrib/test_atom.py
index 4117d571..5a10556e 100644
--- a/tests/contrib/test_atom.py
+++ b/tests/contrib/test_atom.py
@@ -9,9 +9,12 @@
:license: BSD-3-Clause
"""
import datetime
+
import pytest
-from werkzeug.contrib.atom import format_iso8601, AtomFeed, FeedEntry
+from werkzeug.contrib.atom import AtomFeed
+from werkzeug.contrib.atom import FeedEntry
+from werkzeug.contrib.atom import format_iso8601
class TestAtomFeed(object):
@@ -25,26 +28,25 @@ class TestAtomFeed(object):
def test_atom_title_no_id(self):
with pytest.raises(ValueError):
- AtomFeed(title='test_title')
+ AtomFeed(title="test_title")
def test_atom_add_one(self):
- a = AtomFeed(title='test_title', id=1)
- f = FeedEntry(
- title='test_title', id=1, updated=datetime.datetime.now())
+ a = AtomFeed(title="test_title", id=1)
+ f = FeedEntry(title="test_title", id=1, updated=datetime.datetime.now())
assert len(a.entries) == 0
a.add(f)
assert len(a.entries) == 1
def test_atom_add_one_kwargs(self):
- a = AtomFeed(title='test_title', id=1)
+ a = AtomFeed(title="test_title", id=1)
assert len(a.entries) == 0
- a.add(title='test_title', id=1, updated=datetime.datetime.now())
+ a.add(title="test_title", id=1, updated=datetime.datetime.now())
assert len(a.entries) == 1
assert isinstance(a.entries[0], FeedEntry)
def test_atom_to_str(self):
updated_time = datetime.datetime.now()
- expected_repr = '''
+ expected_repr = """
<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title type="text">test_title</title>
@@ -52,10 +54,11 @@ class TestAtomFeed(object):
<updated>%s</updated>
<generator>Werkzeug</generator>
</feed>
- ''' % format_iso8601(updated_time)
- a = AtomFeed(title='test_title', id=1, updated=updated_time)
- assert str(a).strip().replace(' ', '') == \
- expected_repr.strip().replace(' ', '')
+ """ % format_iso8601(
+ updated_time
+ )
+ a = AtomFeed(title="test_title", id=1, updated=updated_time)
+ assert str(a).strip().replace(" ", "") == expected_repr.strip().replace(" ", "")
class TestFeedEntry(object):
@@ -69,35 +72,38 @@ class TestFeedEntry(object):
def test_feed_entry_no_id(self):
with pytest.raises(ValueError):
- FeedEntry(title='test_title')
+ FeedEntry(title="test_title")
def test_feed_entry_no_updated(self):
with pytest.raises(ValueError):
- FeedEntry(title='test_title', id=1)
+ FeedEntry(title="test_title", id=1)
def test_feed_entry_to_str(self):
updated_time = datetime.datetime.now()
- expected_feed_entry_str = '''
+ expected_feed_entry_str = """
<entry>
<title type="text">test_title</title>
<id>1</id>
<updated>%s</updated>
</entry>
- ''' % format_iso8601(updated_time)
+ """ % format_iso8601(
+ updated_time
+ )
- f = FeedEntry(title='test_title', id=1, updated=updated_time)
- assert str(f).strip().replace(' ', '') == \
- expected_feed_entry_str.strip().replace(' ', '')
+ f = FeedEntry(title="test_title", id=1, updated=updated_time)
+ assert str(f).strip().replace(
+ " ", ""
+ ) == expected_feed_entry_str.strip().replace(" ", "")
def test_format_iso8601():
# naive datetime should be treated as utc
dt = datetime.datetime(2014, 8, 31, 2, 5, 6)
- assert format_iso8601(dt) == '2014-08-31T02:05:06Z'
+ assert format_iso8601(dt) == "2014-08-31T02:05:06Z"
# tz-aware datetime
dt = datetime.datetime(2014, 8, 31, 11, 5, 6, tzinfo=KST())
- assert format_iso8601(dt) == '2014-08-31T11:05:06+09:00'
+ assert format_iso8601(dt) == "2014-08-31T11:05:06+09:00"
class KST(datetime.tzinfo):
@@ -108,7 +114,7 @@ class KST(datetime.tzinfo):
return datetime.timedelta(hours=9)
def tzname(self, dt):
- return 'KST'
+ return "KST"
def dst(self, dt):
return datetime.timedelta(0)
diff --git a/tests/contrib/test_fixers.py b/tests/contrib/test_fixers.py
index 79e5cc7f..2777f611 100644
--- a/tests/contrib/test_fixers.py
+++ b/tests/contrib/test_fixers.py
@@ -9,80 +9,78 @@ from werkzeug.wrappers import Response
@Request.application
def path_check_app(request):
- return Response('PATH_INFO: %s\nSCRIPT_NAME: %s' % (
- request.environ.get('PATH_INFO', ''),
- request.environ.get('SCRIPT_NAME', '')
- ))
+ return Response(
+ "PATH_INFO: %s\nSCRIPT_NAME: %s"
+ % (request.environ.get("PATH_INFO", ""), request.environ.get("SCRIPT_NAME", ""))
+ )
class TestServerFixer(object):
-
def test_cgi_root_fix(self):
app = fixers.CGIRootFix(path_check_app)
response = Response.from_app(
- app,
- dict(create_environ(),
- SCRIPT_NAME='/foo',
- PATH_INFO='/bar'))
- assert response.get_data() == b'PATH_INFO: /bar\nSCRIPT_NAME: '
+ app, dict(create_environ(), SCRIPT_NAME="/foo", PATH_INFO="/bar")
+ )
+ assert response.get_data() == b"PATH_INFO: /bar\nSCRIPT_NAME: "
def test_cgi_root_fix_custom_app_root(self):
- app = fixers.CGIRootFix(path_check_app, app_root='/baz/')
+ app = fixers.CGIRootFix(path_check_app, app_root="/baz/")
response = Response.from_app(
- app,
- dict(create_environ(),
- SCRIPT_NAME='/foo',
- PATH_INFO='/bar'))
- assert response.get_data() == b'PATH_INFO: /bar\nSCRIPT_NAME: baz'
+ app, dict(create_environ(), SCRIPT_NAME="/foo", PATH_INFO="/bar")
+ )
+ assert response.get_data() == b"PATH_INFO: /bar\nSCRIPT_NAME: baz"
def test_path_info_from_request_uri_fix(self):
app = fixers.PathInfoFromRequestUriFix(path_check_app)
- for key in 'REQUEST_URI', 'REQUEST_URL', 'UNENCODED_URL':
- env = dict(create_environ(), SCRIPT_NAME='/test', PATH_INFO='/?????')
- env[key] = '/test/foo%25bar?drop=this'
+ for key in "REQUEST_URI", "REQUEST_URL", "UNENCODED_URL":
+ env = dict(create_environ(), SCRIPT_NAME="/test", PATH_INFO="/?????")
+ env[key] = "/test/foo%25bar?drop=this"
response = Response.from_app(app, env)
- assert response.get_data() == b'PATH_INFO: /foo%bar\nSCRIPT_NAME: /test'
+ assert response.get_data() == b"PATH_INFO: /foo%bar\nSCRIPT_NAME: /test"
def test_header_rewriter_fix(self):
@Request.application
def application(request):
- return Response("", headers=[
- ('X-Foo', 'bar')
- ])
- application = fixers.HeaderRewriterFix(application, ('X-Foo',), (('X-Bar', '42'),))
+ return Response("", headers=[("X-Foo", "bar")])
+
+ application = fixers.HeaderRewriterFix(
+ application, ("X-Foo",), (("X-Bar", "42"),)
+ )
response = Response.from_app(application, create_environ())
- assert response.headers['Content-Type'] == 'text/plain; charset=utf-8'
- assert 'X-Foo' not in response.headers
- assert response.headers['X-Bar'] == '42'
+ assert response.headers["Content-Type"] == "text/plain; charset=utf-8"
+ assert "X-Foo" not in response.headers
+ assert response.headers["X-Bar"] == "42"
class TestBrowserFixer(object):
-
def test_ie_fixes(self):
@fixers.InternetExplorerFix
@Request.application
def application(request):
- response = Response('binary data here', mimetype='application/vnd.ms-excel')
- response.headers['Vary'] = 'Cookie'
- response.headers['Content-Disposition'] = 'attachment; filename=foo.xls'
+ response = Response("binary data here", mimetype="application/vnd.ms-excel")
+ response.headers["Vary"] = "Cookie"
+ response.headers["Content-Disposition"] = "attachment; filename=foo.xls"
return response
c = Client(application, Response)
- response = c.get('/', headers=[
- ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)')
- ])
+ response = c.get(
+ "/",
+ headers=[
+ ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)")
+ ],
+ )
# IE gets no vary
- assert response.get_data() == b'binary data here'
- assert 'vary' not in response.headers
- assert response.headers['content-disposition'] == 'attachment; filename=foo.xls'
- assert response.headers['content-type'] == 'application/vnd.ms-excel'
+ assert response.get_data() == b"binary data here"
+ assert "vary" not in response.headers
+ assert response.headers["content-disposition"] == "attachment; filename=foo.xls"
+ assert response.headers["content-type"] == "application/vnd.ms-excel"
# other browsers do
c = Client(application, Response)
- response = c.get('/')
- assert response.get_data() == b'binary data here'
- assert 'vary' in response.headers
+ response = c.get("/")
+ assert response.get_data() == b"binary data here"
+ assert "vary" in response.headers
cc = ResponseCacheControl()
cc.no_cache = True
@@ -90,40 +88,47 @@ class TestBrowserFixer(object):
@fixers.InternetExplorerFix
@Request.application
def application(request):
- response = Response('binary data here', mimetype='application/vnd.ms-excel')
- response.headers['Pragma'] = ', '.join(pragma)
- response.headers['Cache-Control'] = cc.to_header()
- response.headers['Content-Disposition'] = 'attachment; filename=foo.xls'
+ response = Response("binary data here", mimetype="application/vnd.ms-excel")
+ response.headers["Pragma"] = ", ".join(pragma)
+ response.headers["Cache-Control"] = cc.to_header()
+ response.headers["Content-Disposition"] = "attachment; filename=foo.xls"
return response
# IE has no pragma or cache control
- pragma = ('no-cache',)
+ pragma = ("no-cache",)
c = Client(application, Response)
- response = c.get('/', headers=[
- ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)')
- ])
- assert response.get_data() == b'binary data here'
- assert 'pragma' not in response.headers
- assert 'cache-control' not in response.headers
- assert response.headers['content-disposition'] == 'attachment; filename=foo.xls'
+ response = c.get(
+ "/",
+ headers=[
+ ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)")
+ ],
+ )
+ assert response.get_data() == b"binary data here"
+ assert "pragma" not in response.headers
+ assert "cache-control" not in response.headers
+ assert response.headers["content-disposition"] == "attachment; filename=foo.xls"
# IE has simplified pragma
- pragma = ('no-cache', 'x-foo')
+ pragma = ("no-cache", "x-foo")
cc.proxy_revalidate = True
- response = c.get('/', headers=[
- ('User-Agent', 'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)')
- ])
- assert response.get_data() == b'binary data here'
- assert response.headers['pragma'] == 'x-foo'
- assert response.headers['cache-control'] == 'proxy-revalidate'
- assert response.headers['content-disposition'] == 'attachment; filename=foo.xls'
+ response = c.get(
+ "/",
+ headers=[
+ ("User-Agent", "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.0)")
+ ],
+ )
+ assert response.get_data() == b"binary data here"
+ assert response.headers["pragma"] == "x-foo"
+ assert response.headers["cache-control"] == "proxy-revalidate"
+ assert response.headers["content-disposition"] == "attachment; filename=foo.xls"
# regular browsers get everything
- response = c.get('/')
- assert response.get_data() == b'binary data here'
- assert response.headers['pragma'] == 'no-cache, x-foo'
- cc = parse_cache_control_header(response.headers['cache-control'],
- cls=ResponseCacheControl)
+ response = c.get("/")
+ assert response.get_data() == b"binary data here"
+ assert response.headers["pragma"] == "no-cache, x-foo"
+ cc = parse_cache_control_header(
+ response.headers["cache-control"], cls=ResponseCacheControl
+ )
assert cc.no_cache
assert cc.proxy_revalidate
- assert response.headers['content-disposition'] == 'attachment; filename=foo.xls'
+ assert response.headers["content-disposition"] == "attachment; filename=foo.xls"
diff --git a/tests/contrib/test_iterio.py b/tests/contrib/test_iterio.py
index 00a22d88..e2c48130 100644
--- a/tests/contrib/test_iterio.py
+++ b/tests/contrib/test_iterio.py
@@ -10,12 +10,12 @@
"""
import pytest
-from tests import strict_eq
-from werkzeug.contrib.iterio import IterIO, greenlet
+from .. import strict_eq
+from werkzeug.contrib.iterio import greenlet
+from werkzeug.contrib.iterio import IterIO
class TestIterO(object):
-
def test_basic_native(self):
io = IterIO(["Hello", "World", "1", "2", "3"])
io.seek(0)
@@ -34,24 +34,24 @@ class TestIterO(object):
assert io.closed
io = IterIO(["Hello\n", "World!"])
- assert io.readline() == 'Hello\n'
- assert io._buf == 'Hello\n'
- assert io.read() == 'World!'
- assert io._buf == 'Hello\nWorld!'
+ assert io.readline() == "Hello\n"
+ assert io._buf == "Hello\n"
+ assert io.read() == "World!"
+ assert io._buf == "Hello\nWorld!"
assert io.tell() == 12
io.seek(0)
- assert io.readlines() == ['Hello\n', 'World!']
+ assert io.readlines() == ["Hello\n", "World!"]
- io = IterIO(['Line one\nLine ', 'two\nLine three'])
- assert list(io) == ['Line one\n', 'Line two\n', 'Line three']
- io = IterIO(iter('Line one\nLine two\nLine three'))
- assert list(io) == ['Line one\n', 'Line two\n', 'Line three']
- io = IterIO(['Line one\nL', 'ine', ' two', '\nLine three'])
- assert list(io) == ['Line one\n', 'Line two\n', 'Line three']
+ io = IterIO(["Line one\nLine ", "two\nLine three"])
+ assert list(io) == ["Line one\n", "Line two\n", "Line three"]
+ io = IterIO(iter("Line one\nLine two\nLine three"))
+ assert list(io) == ["Line one\n", "Line two\n", "Line three"]
+ io = IterIO(["Line one\nL", "ine", " two", "\nLine three"])
+ assert list(io) == ["Line one\n", "Line two\n", "Line three"]
io = IterIO(["foo\n", "bar"])
io.seek(-4, 2)
- assert io.read(4) == '\nbar'
+ assert io.read(4) == "\nbar"
pytest.raises(IOError, io.seek, 2, 100)
io.close()
@@ -74,17 +74,17 @@ class TestIterO(object):
assert io.closed
io = IterIO([b"Hello\n", b"World!"])
- assert io.readline() == b'Hello\n'
- assert io._buf == b'Hello\n'
- assert io.read() == b'World!'
- assert io._buf == b'Hello\nWorld!'
+ assert io.readline() == b"Hello\n"
+ assert io._buf == b"Hello\n"
+ assert io.read() == b"World!"
+ assert io._buf == b"Hello\nWorld!"
assert io.tell() == 12
io.seek(0)
- assert io.readlines() == [b'Hello\n', b'World!']
+ assert io.readlines() == [b"Hello\n", b"World!"]
io = IterIO([b"foo\n", b"bar"])
io.seek(-4, 2)
- assert io.read(4) == b'\nbar'
+ assert io.read(4) == b"\nbar"
pytest.raises(IOError, io.seek, 2, 100)
io.close()
@@ -107,17 +107,17 @@ class TestIterO(object):
assert io.closed
io = IterIO([u"Hello\n", u"World!"])
- assert io.readline() == u'Hello\n'
- assert io._buf == u'Hello\n'
- assert io.read() == u'World!'
- assert io._buf == u'Hello\nWorld!'
+ assert io.readline() == u"Hello\n"
+ assert io._buf == u"Hello\n"
+ assert io.read() == u"World!"
+ assert io._buf == u"Hello\nWorld!"
assert io.tell() == 12
io.seek(0)
- assert io.readlines() == [u'Hello\n', u'World!']
+ assert io.readlines() == [u"Hello\n", u"World!"]
io = IterIO([u"foo\n", u"bar"])
io.seek(-4, 2)
- assert io.read(4) == u'\nbar'
+ assert io.read(4) == u"\nbar"
pytest.raises(IOError, io.seek, 2, 100)
io.close()
@@ -125,60 +125,62 @@ class TestIterO(object):
def test_sentinel_cases(self):
io = IterIO([])
- strict_eq(io.read(), '')
- io = IterIO([], b'')
- strict_eq(io.read(), b'')
- io = IterIO([], u'')
- strict_eq(io.read(), u'')
+ strict_eq(io.read(), "")
+ io = IterIO([], b"")
+ strict_eq(io.read(), b"")
+ io = IterIO([], u"")
+ strict_eq(io.read(), u"")
io = IterIO([])
- strict_eq(io.read(), '')
- io = IterIO([b''])
- strict_eq(io.read(), b'')
- io = IterIO([u''])
- strict_eq(io.read(), u'')
+ strict_eq(io.read(), "")
+ io = IterIO([b""])
+ strict_eq(io.read(), b"")
+ io = IterIO([u""])
+ strict_eq(io.read(), u"")
io = IterIO([])
- strict_eq(io.readline(), '')
- io = IterIO([], b'')
- strict_eq(io.readline(), b'')
- io = IterIO([], u'')
- strict_eq(io.readline(), u'')
+ strict_eq(io.readline(), "")
+ io = IterIO([], b"")
+ strict_eq(io.readline(), b"")
+ io = IterIO([], u"")
+ strict_eq(io.readline(), u"")
io = IterIO([])
- strict_eq(io.readline(), '')
- io = IterIO([b''])
- strict_eq(io.readline(), b'')
- io = IterIO([u''])
- strict_eq(io.readline(), u'')
+ strict_eq(io.readline(), "")
+ io = IterIO([b""])
+ strict_eq(io.readline(), b"")
+ io = IterIO([u""])
+ strict_eq(io.readline(), u"")
-@pytest.mark.skipif(greenlet is None, reason='Greenlet is not installed.')
+@pytest.mark.skipif(greenlet is None, reason="Greenlet is not installed.")
class TestIterI(object):
-
def test_basic(self):
def producer(out):
- out.write('1\n')
- out.write('2\n')
+ out.write("1\n")
+ out.write("2\n")
out.flush()
- out.write('3\n')
+ out.write("3\n")
+
iterable = IterIO(producer)
- assert next(iterable) == '1\n2\n'
- assert next(iterable) == '3\n'
+ assert next(iterable) == "1\n2\n"
+ assert next(iterable) == "3\n"
pytest.raises(StopIteration, next, iterable)
def test_sentinel_cases(self):
def producer_dummy_flush(out):
out.flush()
+
iterable = IterIO(producer_dummy_flush)
- strict_eq(next(iterable), '')
+ strict_eq(next(iterable), "")
def producer_empty(out):
pass
+
iterable = IterIO(producer_empty)
pytest.raises(StopIteration, next, iterable)
- iterable = IterIO(producer_dummy_flush, b'')
- strict_eq(next(iterable), b'')
- iterable = IterIO(producer_dummy_flush, u'')
- strict_eq(next(iterable), u'')
+ iterable = IterIO(producer_dummy_flush, b"")
+ strict_eq(next(iterable), b"")
+ iterable = IterIO(producer_dummy_flush, u"")
+ strict_eq(next(iterable), u"")
diff --git a/tests/contrib/test_securecookie.py b/tests/contrib/test_securecookie.py
index 80e82147..7231ac88 100644
--- a/tests/contrib/test_securecookie.py
+++ b/tests/contrib/test_securecookie.py
@@ -9,54 +9,59 @@
:license: BSD-3-Clause
"""
import json
+
import pytest
from werkzeug._compat import to_native
-from werkzeug.utils import parse_cookie
-from werkzeug.wrappers import Request, Response
from werkzeug.contrib.securecookie import SecureCookie
+from werkzeug.utils import parse_cookie
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def test_basic_support():
- c = SecureCookie(secret_key=b'foo')
+ c = SecureCookie(secret_key=b"foo")
assert c.new
assert not c.modified
assert not c.should_save
- c['x'] = 42
+ c["x"] = 42
assert c.modified
assert c.should_save
s = c.serialize()
- c2 = SecureCookie.unserialize(s, b'foo')
+ c2 = SecureCookie.unserialize(s, b"foo")
assert c is not c2
assert not c2.new
assert not c2.modified
assert not c2.should_save
assert c2 == c
- c3 = SecureCookie.unserialize(s, b'wrong foo')
+ c3 = SecureCookie.unserialize(s, b"wrong foo")
assert not c3.modified
assert not c3.new
assert c3 == {}
- c4 = SecureCookie({'x': 42}, 'foo')
+ c4 = SecureCookie({"x": 42}, "foo")
c4_serialized = c4.serialize()
- assert SecureCookie.unserialize(c4_serialized, 'foo') == c4
+ assert SecureCookie.unserialize(c4_serialized, "foo") == c4
def test_wrapper_support():
req = Request.from_values()
resp = Response()
- c = SecureCookie.load_cookie(req, secret_key=b'foo')
+ c = SecureCookie.load_cookie(req, secret_key=b"foo")
assert c.new
- c['foo'] = 42
- assert c.secret_key == b'foo'
+ c["foo"] = 42
+ assert c.secret_key == b"foo"
c.save_cookie(resp)
- req = Request.from_values(headers={
- 'Cookie': 'session="%s"' % parse_cookie(resp.headers['set-cookie'])['session']
- })
- c2 = SecureCookie.load_cookie(req, secret_key=b'foo')
+ req = Request.from_values(
+ headers={
+ "Cookie": 'session="%s"'
+ % parse_cookie(resp.headers["set-cookie"])["session"]
+ }
+ )
+ c2 = SecureCookie.load_cookie(req, secret_key=b"foo")
assert not c2.new
assert c2 == c
diff --git a/tests/contrib/test_sessions.py b/tests/contrib/test_sessions.py
index a941652f..cab0ae56 100644
--- a/tests/contrib/test_sessions.py
+++ b/tests/contrib/test_sessions.py
@@ -24,7 +24,7 @@ def test_basic_fs_sessions(tmpdir):
x = store.new()
assert x.new
assert not x.modified
- x['foo'] = [1, 2, 3]
+ x["foo"] = [1, 2, 3]
assert x.modified
store.save(x)
@@ -33,7 +33,7 @@ def test_basic_fs_sessions(tmpdir):
assert not x2.modified
assert x2 is not x
assert x2 == x
- x2['test'] = 3
+ x2["test"] = 3
assert x2.modified
assert not x2.new
store.save(x2)
@@ -67,7 +67,7 @@ def test_renewing_fs_session(tmpdir):
def test_fs_session_lising(tmpdir):
store = FilesystemSessionStore(str(tmpdir), renew_missing=True)
sessions = set()
- for x in range(10):
+ for _ in range(10):
sess = store.new()
store.save(sess)
sessions.add(sess.sid)
diff --git a/tests/contrib/test_wrappers.py b/tests/contrib/test_wrappers.py
index 7bc2e6c8..fb49337f 100644
--- a/tests/contrib/test_wrappers.py
+++ b/tests/contrib/test_wrappers.py
@@ -8,78 +8,81 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
-from werkzeug.contrib import wrappers
from werkzeug import routing
-from werkzeug.wrappers import Request, Response
+from werkzeug.contrib import wrappers
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def test_reverse_slash_behavior():
class MyRequest(wrappers.ReverseSlashBehaviorRequestMixin, Request):
pass
- req = MyRequest.from_values('/foo/bar', 'http://example.com/test')
- assert req.url == 'http://example.com/test/foo/bar'
- assert req.path == 'foo/bar'
- assert req.script_root == '/test/'
+
+ req = MyRequest.from_values("/foo/bar", "http://example.com/test")
+ assert req.url == "http://example.com/test/foo/bar"
+ assert req.path == "foo/bar"
+ assert req.script_root == "/test/"
# make sure the routing system works with the slashes in
# reverse order as well.
- map = routing.Map([routing.Rule('/foo/bar', endpoint='foo')])
+ map = routing.Map([routing.Rule("/foo/bar", endpoint="foo")])
adapter = map.bind_to_environ(req.environ)
- assert adapter.match() == ('foo', {})
+ assert adapter.match() == ("foo", {})
adapter = map.bind(req.host, req.script_root)
- assert adapter.match(req.path) == ('foo', {})
+ assert adapter.match(req.path) == ("foo", {})
def test_dynamic_charset_request_mixin():
class MyRequest(wrappers.DynamicCharsetRequestMixin, Request):
pass
- env = {'CONTENT_TYPE': 'text/html'}
+
+ env = {"CONTENT_TYPE": "text/html"}
req = MyRequest(env)
- assert req.charset == 'latin1'
+ assert req.charset == "latin1"
- env = {'CONTENT_TYPE': 'text/html; charset=utf-8'}
+ env = {"CONTENT_TYPE": "text/html; charset=utf-8"}
req = MyRequest(env)
- assert req.charset == 'utf-8'
+ assert req.charset == "utf-8"
- env = {'CONTENT_TYPE': 'application/octet-stream'}
+ env = {"CONTENT_TYPE": "application/octet-stream"}
req = MyRequest(env)
- assert req.charset == 'latin1'
- assert req.url_charset == 'latin1'
+ assert req.charset == "latin1"
+ assert req.url_charset == "latin1"
- MyRequest.url_charset = 'utf-8'
- env = {'CONTENT_TYPE': 'application/octet-stream'}
+ MyRequest.url_charset = "utf-8"
+ env = {"CONTENT_TYPE": "application/octet-stream"}
req = MyRequest(env)
- assert req.charset == 'latin1'
- assert req.url_charset == 'utf-8'
+ assert req.charset == "latin1"
+ assert req.url_charset == "utf-8"
def return_ascii(x):
return "ascii"
- env = {'CONTENT_TYPE': 'text/plain; charset=x-weird-charset'}
+
+ env = {"CONTENT_TYPE": "text/plain; charset=x-weird-charset"}
req = MyRequest(env)
req.unknown_charset = return_ascii
- assert req.charset == 'ascii'
- assert req.url_charset == 'utf-8'
+ assert req.charset == "ascii"
+ assert req.url_charset == "utf-8"
def test_dynamic_charset_response_mixin():
class MyResponse(wrappers.DynamicCharsetResponseMixin, Response):
- default_charset = 'utf-7'
- resp = MyResponse(mimetype='text/html')
- assert resp.charset == 'utf-7'
- resp.charset = 'utf-8'
- assert resp.charset == 'utf-8'
- assert resp.mimetype == 'text/html'
- assert resp.mimetype_params == {'charset': 'utf-8'}
- resp.mimetype_params['charset'] = 'iso-8859-15'
- assert resp.charset == 'iso-8859-15'
- resp.set_data(u'Hällo Wörld')
- assert b''.join(resp.iter_encoded()) == \
- u'Hällo Wörld'.encode('iso-8859-15')
- del resp.headers['content-type']
+ default_charset = "utf-7"
+
+ resp = MyResponse(mimetype="text/html")
+ assert resp.charset == "utf-7"
+ resp.charset = "utf-8"
+ assert resp.charset == "utf-8"
+ assert resp.mimetype == "text/html"
+ assert resp.mimetype_params == {"charset": "utf-8"}
+ resp.mimetype_params["charset"] = "iso-8859-15"
+ assert resp.charset == "iso-8859-15"
+ resp.set_data(u"Hällo Wörld")
+ assert b"".join(resp.iter_encoded()) == u"Hällo Wörld".encode("iso-8859-15")
+ del resp.headers["content-type"]
try:
- resp.charset = 'utf-8'
+ resp.charset = "utf-8"
except TypeError:
pass
else:
- assert False, 'expected type error on charset setting without ct'
+ assert False, "expected type error on charset setting without ct"
diff --git a/tests/hypothesis/test_urls.py b/tests/hypothesis/test_urls.py
index c610714d..61829b3c 100644
--- a/tests/hypothesis/test_urls.py
+++ b/tests/hypothesis/test_urls.py
@@ -1,5 +1,8 @@
import hypothesis
-from hypothesis.strategies import text, dictionaries, lists, integers
+from hypothesis.strategies import dictionaries
+from hypothesis.strategies import integers
+from hypothesis.strategies import lists
+from hypothesis.strategies import text
from werkzeug import urls
from werkzeug.datastructures import OrderedMultiDict
@@ -22,8 +25,9 @@ def test_url_encoding_dict_str_list(d):
@hypothesis.given(dictionaries(text(), integers()))
def test_url_encoding_dict_str_int(d):
- assert OrderedMultiDict({k: str(v) for k, v in d.items()}) == \
- urls.url_decode(urls.url_encode(d))
+ assert OrderedMultiDict({k: str(v) for k, v in d.items()}) == urls.url_decode(
+ urls.url_encode(d)
+ )
@hypothesis.given(text(), text())
diff --git a/tests/multipart/firefox3-2png1txt/request.txt b/tests/multipart/firefox3-2png1txt/request.http
index 721e04e3..721e04e3 100644
--- a/tests/multipart/firefox3-2png1txt/request.txt
+++ b/tests/multipart/firefox3-2png1txt/request.http
Binary files differ
diff --git a/tests/multipart/firefox3-2png1txt/text.txt b/tests/multipart/firefox3-2png1txt/text.txt
index c87634d5..4491a1e4 100644
--- a/tests/multipart/firefox3-2png1txt/text.txt
+++ b/tests/multipart/firefox3-2png1txt/text.txt
@@ -1 +1 @@
-example text \ No newline at end of file
+example text
diff --git a/tests/multipart/firefox3-2pnglongtext/request.txt b/tests/multipart/firefox3-2pnglongtext/request.http
index 489290b6..489290b6 100644
--- a/tests/multipart/firefox3-2pnglongtext/request.txt
+++ b/tests/multipart/firefox3-2pnglongtext/request.http
Binary files differ
diff --git a/tests/multipart/firefox3-2pnglongtext/text.txt b/tests/multipart/firefox3-2pnglongtext/text.txt
index 3bf804d5..833ab62e 100644
--- a/tests/multipart/firefox3-2pnglongtext/text.txt
+++ b/tests/multipart/firefox3-2pnglongtext/text.txt
@@ -1,3 +1,3 @@
--long text
--with boundary
---lookalikes-- \ No newline at end of file
+--lookalikes--
diff --git a/tests/multipart/ie6-2png1txt/request.txt b/tests/multipart/ie6-2png1txt/request.http
index 59fdeae2..59fdeae2 100644
--- a/tests/multipart/ie6-2png1txt/request.txt
+++ b/tests/multipart/ie6-2png1txt/request.http
Binary files differ
diff --git a/tests/multipart/ie6-2png1txt/text.txt b/tests/multipart/ie6-2png1txt/text.txt
index 7c465b7a..a32e65e6 100644
--- a/tests/multipart/ie6-2png1txt/text.txt
+++ b/tests/multipart/ie6-2png1txt/text.txt
@@ -1 +1 @@
-ie6 sucks :-/ \ No newline at end of file
+ie6 sucks :-/
diff --git a/tests/multipart/ie7_full_path_request.txt b/tests/multipart/ie7_full_path_request.http
index acc4e2e1..acc4e2e1 100644
--- a/tests/multipart/ie7_full_path_request.txt
+++ b/tests/multipart/ie7_full_path_request.http
Binary files differ
diff --git a/tests/multipart/opera8-2png1txt/request.txt b/tests/multipart/opera8-2png1txt/request.http
index 8f325914..8f325914 100644
--- a/tests/multipart/opera8-2png1txt/request.txt
+++ b/tests/multipart/opera8-2png1txt/request.http
Binary files differ
diff --git a/tests/multipart/opera8-2png1txt/text.txt b/tests/multipart/opera8-2png1txt/text.txt
index ca01cb00..ea10aa51 100644
--- a/tests/multipart/opera8-2png1txt/text.txt
+++ b/tests/multipart/opera8-2png1txt/text.txt
@@ -1 +1 @@
-blafasel öäü \ No newline at end of file
+blafasel öäü
diff --git a/tests/multipart/test_collect.py b/tests/multipart/test_collect.py
index 395d8cb0..01b89bbf 100644
--- a/tests/multipart/test_collect.py
+++ b/tests/multipart/test_collect.py
@@ -3,55 +3,58 @@
Hacky helper application to collect form data.
"""
from werkzeug.serving import run_simple
-from werkzeug.wrappers import Request, Response
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def copy_stream(request):
from os import mkdir
from time import time
- folder = 'request-%d' % time()
+
+ folder = "request-%d" % time()
mkdir(folder)
environ = request.environ
- f = open(folder + '/request.txt', 'wb+')
- f.write(environ['wsgi.input'].read(int(environ['CONTENT_LENGTH'])))
+ f = open(folder + "/request.http", "wb+")
+ f.write(environ["wsgi.input"].read(int(environ["CONTENT_LENGTH"])))
f.flush()
f.seek(0)
- environ['wsgi.input'] = f
+ environ["wsgi.input"] = f
request.stat_folder = folder
def stats(request):
copy_stream(request)
- f1 = request.files['file1']
- f2 = request.files['file2']
- text = request.form['text']
- f1.save(request.stat_folder + '/file1.bin')
- f2.save(request.stat_folder + '/file2.bin')
- with open(request.stat_folder + '/text.txt', 'w') as f:
- f.write(text.encode('utf-8'))
- return Response('Done.')
+ f1 = request.files["file1"]
+ f2 = request.files["file2"]
+ text = request.form["text"]
+ f1.save(request.stat_folder + "/file1.bin")
+ f2.save(request.stat_folder + "/file2.bin")
+ with open(request.stat_folder + "/text.txt", "w") as f:
+ f.write(text.encode("utf-8"))
+ return Response("Done.")
def upload_file(request):
- return Response('''
- <h1>Upload File</h1>
- <form action="" method="post" enctype="multipart/form-data">
- <input type="file" name="file1"><br>
- <input type="file" name="file2"><br>
- <textarea name="text"></textarea><br>
- <input type="submit" value="Send">
- </form>
- ''', mimetype='text/html')
+ return Response(
+ """<h1>Upload File</h1>
+ <form action="" method="post" enctype="multipart/form-data">
+ <input type="file" name="file1"><br>
+ <input type="file" name="file2"><br>
+ <textarea name="text"></textarea><br>
+ <input type="submit" value="Send">
+ </form>""",
+ mimetype="text/html",
+ )
def application(environ, start_responseonse):
request = Request(environ)
- if request.method == 'POST':
+ if request.method == "POST":
response = stats(request)
else:
response = upload_file(request)
return response(environ, start_responseonse)
-if __name__ == '__main__':
- run_simple('localhost', 5000, application, use_debugger=True)
+if __name__ == "__main__":
+ run_simple("localhost", 5000, application, use_debugger=True)
diff --git a/tests/multipart/webkit3-2png1txt/request.txt b/tests/multipart/webkit3-2png1txt/request.http
index b4ce0eef..b4ce0eef 100644
--- a/tests/multipart/webkit3-2png1txt/request.txt
+++ b/tests/multipart/webkit3-2png1txt/request.http
Binary files differ
diff --git a/tests/multipart/webkit3-2png1txt/text.txt b/tests/multipart/webkit3-2png1txt/text.txt
index baa13008..17537909 100644
--- a/tests/multipart/webkit3-2png1txt/text.txt
+++ b/tests/multipart/webkit3-2png1txt/text.txt
@@ -1 +1 @@
-this is another text with ümläüts \ No newline at end of file
+this is another text with ümläüts
diff --git a/tests/res/chunked.txt b/tests/res/chunked.http
index 44aa71f6..44aa71f6 100644
--- a/tests/res/chunked.txt
+++ b/tests/res/chunked.http
diff --git a/tests/test_compat.py b/tests/test_compat.py
index ad8e8a4d..98851ba2 100644
--- a/tests/test_compat.py
+++ b/tests/test_compat.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
+# flake8: noqa
"""
tests.compat
~~~~~~~~~~~~
@@ -8,25 +9,32 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
-# This file shouldn't be linted:
-# flake8: noqa
-
-import warnings
-
-from werkzeug.wrappers import Response
from werkzeug.test import create_environ
+from werkzeug.wrappers import Response
def test_old_imports():
- from werkzeug.utils import Headers, MultiDict, CombinedMultiDict, \
- Headers, EnvironHeaders
- from werkzeug.http import Accept, MIMEAccept, CharsetAccept, \
- LanguageAccept, ETags, HeaderSet, WWWAuthenticate, \
- Authorization
+ from werkzeug.utils import (
+ Headers,
+ MultiDict,
+ CombinedMultiDict,
+ Headers,
+ EnvironHeaders,
+ )
+ from werkzeug.http import (
+ Accept,
+ MIMEAccept,
+ CharsetAccept,
+ LanguageAccept,
+ ETags,
+ HeaderSet,
+ WWWAuthenticate,
+ Authorization,
+ )
def test_exposed_werkzeug_mod():
import werkzeug
+
for key in werkzeug.__all__:
getattr(werkzeug, key)
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index 016e0544..20aabf62 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -18,43 +18,46 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
import io
-import pytest
-import tempfile
-
-from tests import strict_eq
-
-
import pickle
+import tempfile
from contextlib import contextmanager
-from copy import copy, deepcopy
+from copy import copy
+from copy import deepcopy
+
+import pytest
-from werkzeug import datastructures, http
-from werkzeug._compat import iterkeys, itervalues, iteritems, iterlists, \
- iterlistvalues, text_type, PY2
+from . import strict_eq
+from werkzeug import datastructures
+from werkzeug import http
+from werkzeug._compat import iteritems
+from werkzeug._compat import iterkeys
+from werkzeug._compat import iterlists
+from werkzeug._compat import iterlistvalues
+from werkzeug._compat import itervalues
+from werkzeug._compat import PY2
+from werkzeug._compat import text_type
from werkzeug.datastructures import Range
from werkzeug.exceptions import BadRequestKeyError
class TestNativeItermethods(object):
-
def test_basic(self):
- @datastructures.native_itermethods(['keys', 'values', 'items'])
+ @datastructures.native_itermethods(["keys", "values", "items"])
class StupidDict(object):
-
def keys(self, multi=1):
- return iter(['a', 'b', 'c'] * multi)
+ return iter(["a", "b", "c"] * multi)
def values(self, multi=1):
return iter([1, 2, 3] * multi)
def items(self, multi=1):
- return iter(zip(iterkeys(self, multi=multi),
- itervalues(self, multi=multi)))
+ return iter(
+ zip(iterkeys(self, multi=multi), itervalues(self, multi=multi))
+ )
d = StupidDict()
- expected_keys = ['a', 'b', 'c']
+ expected_keys = ["a", "b", "c"]
expected_values = [1, 2, 3]
expected_items = list(zip(expected_keys, expected_values))
@@ -81,8 +84,8 @@ class _MutableMultiDictTests(object):
cls.__module__ = module
d = cls()
cls.__module__ = old
- d.setlist(b'foo', [1, 2, 3, 4])
- d.setlist(b'bar', b'foo bar baz'.split())
+ d.setlist(b"foo", [1, 2, 3, 4])
+ d.setlist(b"bar", b"foo bar baz".split())
return d
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -91,157 +94,173 @@ class _MutableMultiDictTests(object):
ud = pickle.loads(s)
assert type(ud) == type(d)
assert ud == d
- alternative = pickle.dumps(create_instance('werkzeug'), protocol)
+ alternative = pickle.dumps(create_instance("werkzeug"), protocol)
assert pickle.loads(alternative) == d
- ud[b'newkey'] = b'bla'
+ ud[b"newkey"] = b"bla"
assert ud != d
def test_basic_interface(self):
md = self.storage_class()
assert isinstance(md, dict)
- mapping = [('a', 1), ('b', 2), ('a', 2), ('d', 3),
- ('a', 1), ('a', 3), ('d', 4), ('c', 3)]
+ mapping = [
+ ("a", 1),
+ ("b", 2),
+ ("a", 2),
+ ("d", 3),
+ ("a", 1),
+ ("a", 3),
+ ("d", 4),
+ ("c", 3),
+ ]
md = self.storage_class(mapping)
# simple getitem gives the first value
- assert md['a'] == 1
- assert md['c'] == 3
+ assert md["a"] == 1
+ assert md["c"] == 3
with pytest.raises(KeyError):
- md['e']
- assert md.get('a') == 1
+ md["e"]
+ assert md.get("a") == 1
# list getitem
- assert md.getlist('a') == [1, 2, 1, 3]
- assert md.getlist('d') == [3, 4]
+ assert md.getlist("a") == [1, 2, 1, 3]
+ assert md.getlist("d") == [3, 4]
# do not raise if key not found
- assert md.getlist('x') == []
+ assert md.getlist("x") == []
# simple setitem overwrites all values
- md['a'] = 42
- assert md.getlist('a') == [42]
+ md["a"] = 42
+ assert md.getlist("a") == [42]
# list setitem
- md.setlist('a', [1, 2, 3])
- assert md['a'] == 1
- assert md.getlist('a') == [1, 2, 3]
+ md.setlist("a", [1, 2, 3])
+ assert md["a"] == 1
+ assert md.getlist("a") == [1, 2, 3]
# verify that it does not change original lists
l1 = [1, 2, 3]
- md.setlist('a', l1)
+ md.setlist("a", l1)
del l1[:]
- assert md['a'] == 1
+ assert md["a"] == 1
# setdefault, setlistdefault
- assert md.setdefault('u', 23) == 23
- assert md.getlist('u') == [23]
- del md['u']
+ assert md.setdefault("u", 23) == 23
+ assert md.getlist("u") == [23]
+ del md["u"]
- md.setlist('u', [-1, -2])
+ md.setlist("u", [-1, -2])
# delitem
- del md['u']
+ del md["u"]
with pytest.raises(KeyError):
- md['u']
- del md['d']
- assert md.getlist('d') == []
+ md["u"]
+ del md["d"]
+ assert md.getlist("d") == []
# keys, values, items, lists
- assert list(sorted(md.keys())) == ['a', 'b', 'c']
- assert list(sorted(iterkeys(md))) == ['a', 'b', 'c']
+ assert list(sorted(md.keys())) == ["a", "b", "c"]
+ assert list(sorted(iterkeys(md))) == ["a", "b", "c"]
assert list(sorted(itervalues(md))) == [1, 2, 3]
assert list(sorted(itervalues(md))) == [1, 2, 3]
- assert list(sorted(md.items())) == [('a', 1), ('b', 2), ('c', 3)]
- assert list(sorted(md.items(multi=True))) == \
- [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)]
- assert list(sorted(iteritems(md))) == [('a', 1), ('b', 2), ('c', 3)]
- assert list(sorted(iteritems(md, multi=True))) == \
- [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)]
+ assert list(sorted(md.items())) == [("a", 1), ("b", 2), ("c", 3)]
+ assert list(sorted(md.items(multi=True))) == [
+ ("a", 1),
+ ("a", 2),
+ ("a", 3),
+ ("b", 2),
+ ("c", 3),
+ ]
+ assert list(sorted(iteritems(md))) == [("a", 1), ("b", 2), ("c", 3)]
+ assert list(sorted(iteritems(md, multi=True))) == [
+ ("a", 1),
+ ("a", 2),
+ ("a", 3),
+ ("b", 2),
+ ("c", 3),
+ ]
- assert list(sorted(md.lists())) == \
- [('a', [1, 2, 3]), ('b', [2]), ('c', [3])]
- assert list(sorted(iterlists(md))) == \
- [('a', [1, 2, 3]), ('b', [2]), ('c', [3])]
+ assert list(sorted(md.lists())) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])]
+ assert list(sorted(iterlists(md))) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])]
# copy method
c = md.copy()
- assert c['a'] == 1
- assert c.getlist('a') == [1, 2, 3]
+ assert c["a"] == 1
+ assert c.getlist("a") == [1, 2, 3]
# copy method 2
c = copy(md)
- assert c['a'] == 1
- assert c.getlist('a') == [1, 2, 3]
+ assert c["a"] == 1
+ assert c.getlist("a") == [1, 2, 3]
# deepcopy method
c = md.deepcopy()
- assert c['a'] == 1
- assert c.getlist('a') == [1, 2, 3]
+ assert c["a"] == 1
+ assert c.getlist("a") == [1, 2, 3]
# deepcopy method 2
c = deepcopy(md)
- assert c['a'] == 1
- assert c.getlist('a') == [1, 2, 3]
+ assert c["a"] == 1
+ assert c.getlist("a") == [1, 2, 3]
# update with a multidict
- od = self.storage_class([('a', 4), ('a', 5), ('y', 0)])
+ od = self.storage_class([("a", 4), ("a", 5), ("y", 0)])
md.update(od)
- assert md.getlist('a') == [1, 2, 3, 4, 5]
- assert md.getlist('y') == [0]
+ assert md.getlist("a") == [1, 2, 3, 4, 5]
+ assert md.getlist("y") == [0]
# update with a regular dict
md = c
- od = {'a': 4, 'y': 0}
+ od = {"a": 4, "y": 0}
md.update(od)
- assert md.getlist('a') == [1, 2, 3, 4]
- assert md.getlist('y') == [0]
+ assert md.getlist("a") == [1, 2, 3, 4]
+ assert md.getlist("y") == [0]
# pop, poplist, popitem, popitemlist
- assert md.pop('y') == 0
- assert 'y' not in md
- assert md.poplist('a') == [1, 2, 3, 4]
- assert 'a' not in md
- assert md.poplist('missing') == []
+ assert md.pop("y") == 0
+ assert "y" not in md
+ assert md.poplist("a") == [1, 2, 3, 4]
+ assert "a" not in md
+ assert md.poplist("missing") == []
# remaining: b=2, c=3
popped = md.popitem()
- assert popped in [('b', 2), ('c', 3)]
+ assert popped in [("b", 2), ("c", 3)]
popped = md.popitemlist()
- assert popped in [('b', [2]), ('c', [3])]
+ assert popped in [("b", [2]), ("c", [3])]
# type conversion
- md = self.storage_class({'a': '4', 'b': ['2', '3']})
- assert md.get('a', type=int) == 4
- assert md.getlist('b', type=int) == [2, 3]
+ md = self.storage_class({"a": "4", "b": ["2", "3"]})
+ assert md.get("a", type=int) == 4
+ assert md.getlist("b", type=int) == [2, 3]
# repr
- md = self.storage_class([('a', 1), ('a', 2), ('b', 3)])
+ md = self.storage_class([("a", 1), ("a", 2), ("b", 3)])
assert "('a', 1)" in repr(md)
assert "('a', 2)" in repr(md)
assert "('b', 3)" in repr(md)
# add and getlist
- md.add('c', '42')
- md.add('c', '23')
- assert md.getlist('c') == ['42', '23']
- md.add('c', 'blah')
- assert md.getlist('c', type=int) == [42, 23]
+ md.add("c", "42")
+ md.add("c", "23")
+ assert md.getlist("c") == ["42", "23"]
+ md.add("c", "blah")
+ assert md.getlist("c", type=int) == [42, 23]
# setdefault
md = self.storage_class()
- md.setdefault('x', []).append(42)
- md.setdefault('x', []).append(23)
- assert md['x'] == [42, 23]
+ md.setdefault("x", []).append(42)
+ md.setdefault("x", []).append(23)
+ assert md["x"] == [42, 23]
# to dict
md = self.storage_class()
- md['foo'] = 42
- md.add('bar', 1)
- md.add('bar', 2)
- assert md.to_dict() == {'foo': 42, 'bar': 1}
- assert md.to_dict(flat=False) == {'foo': [42], 'bar': [1, 2]}
+ md["foo"] = 42
+ md.add("bar", 1)
+ md.add("bar", 2)
+ assert md.to_dict() == {"foo": 42, "bar": 1}
+ assert md.to_dict(flat=False) == {"foo": [42], "bar": [1, 2]}
# popitem from empty dict
with pytest.raises(KeyError):
@@ -256,9 +275,9 @@ class _MutableMultiDictTests(object):
# setlist works
md = self.storage_class()
- md['foo'] = 42
- md.setlist('foo', [1, 2])
- assert md.getlist('foo') == [1, 2]
+ md["foo"] = 42
+ md.setlist("foo", [1, 2])
+ assert md.getlist("foo") == [1, 2]
class _ImmutableDictTests(object):
@@ -267,33 +286,33 @@ class _ImmutableDictTests(object):
def test_follows_dict_interface(self):
cls = self.storage_class
- data = {'foo': 1, 'bar': 2, 'baz': 3}
+ data = {"foo": 1, "bar": 2, "baz": 3}
d = cls(data)
- assert d['foo'] == 1
- assert d['bar'] == 2
- assert d['baz'] == 3
- assert sorted(d.keys()) == ['bar', 'baz', 'foo']
- assert 'foo' in d
- assert 'foox' not in d
+ assert d["foo"] == 1
+ assert d["bar"] == 2
+ assert d["baz"] == 3
+ assert sorted(d.keys()) == ["bar", "baz", "foo"]
+ assert "foo" in d
+ assert "foox" not in d
assert len(d) == 3
def test_copies_are_mutable(self):
cls = self.storage_class
- immutable = cls({'a': 1})
+ immutable = cls({"a": 1})
with pytest.raises(TypeError):
- immutable.pop('a')
+ immutable.pop("a")
mutable = immutable.copy()
- mutable.pop('a')
- assert 'a' in immutable
+ mutable.pop("a")
+ assert "a" in immutable
assert mutable is not immutable
assert copy(immutable) is immutable
def test_dict_is_hashable(self):
cls = self.storage_class
- immutable = cls({'a': 1, 'b': 2})
- immutable2 = cls({'a': 2, 'b': 2})
+ immutable = cls({"a": 1, "b": 2})
+ immutable2 = cls({"a": 2, "b": 2})
x = set([immutable])
assert immutable in x
assert immutable2 not in x
@@ -317,8 +336,8 @@ class TestImmutableMultiDict(_ImmutableDictTests):
def test_multidict_is_hashable(self):
cls = self.storage_class
- immutable = cls({'a': [1, 2], 'b': 2})
- immutable2 = cls({'a': [1], 'b': 2})
+ immutable = cls({"a": [1, 2], "b": 2})
+ immutable2 = cls({"a": [1], "b": 2})
x = set([immutable])
assert immutable in x
assert immutable2 not in x
@@ -341,8 +360,8 @@ class TestImmutableOrderedMultiDict(_ImmutableDictTests):
storage_class = datastructures.ImmutableOrderedMultiDict
def test_ordered_multidict_is_hashable(self):
- a = self.storage_class([('a', 1), ('b', 1), ('a', 2)])
- b = self.storage_class([('a', 1), ('a', 2), ('b', 1)])
+ a = self.storage_class([("a", 1), ("b", 1), ("a", 2)])
+ b = self.storage_class([("a", 1), ("a", 2), ("b", 1)])
assert hash(a) != hash(b)
@@ -350,93 +369,102 @@ class TestMultiDict(_MutableMultiDictTests):
storage_class = datastructures.MultiDict
def test_multidict_pop(self):
- make_d = lambda: self.storage_class({'foo': [1, 2, 3, 4]})
+ def make_d():
+ return self.storage_class({"foo": [1, 2, 3, 4]})
+
d = make_d()
- assert d.pop('foo') == 1
+ assert d.pop("foo") == 1
assert not d
d = make_d()
- assert d.pop('foo', 32) == 1
+ assert d.pop("foo", 32) == 1
assert not d
d = make_d()
- assert d.pop('foos', 32) == 32
+ assert d.pop("foos", 32) == 32
assert d
with pytest.raises(KeyError):
- d.pop('foos')
+ d.pop("foos")
def test_multidict_pop_raise_badrequestkeyerror_for_empty_list_value(self):
- mapping = [('a', 'b'), ('a', 'c')]
+ mapping = [("a", "b"), ("a", "c")]
md = self.storage_class(mapping)
- md.setlistdefault('empty', [])
+ md.setlistdefault("empty", [])
with pytest.raises(KeyError):
- md.pop('empty')
+ md.pop("empty")
def test_multidict_popitem_raise_badrequestkeyerror_for_empty_list_value(self):
mapping = []
md = self.storage_class(mapping)
- md.setlistdefault('empty', [])
+ md.setlistdefault("empty", [])
with pytest.raises(BadRequestKeyError):
md.popitem()
def test_setlistdefault(self):
md = self.storage_class()
- assert md.setlistdefault('u', [-1, -2]) == [-1, -2]
- assert md.getlist('u') == [-1, -2]
- assert md['u'] == -1
+ assert md.setlistdefault("u", [-1, -2]) == [-1, -2]
+ assert md.getlist("u") == [-1, -2]
+ assert md["u"] == -1
def test_iter_interfaces(self):
- mapping = [('a', 1), ('b', 2), ('a', 2), ('d', 3),
- ('a', 1), ('a', 3), ('d', 4), ('c', 3)]
+ mapping = [
+ ("a", 1),
+ ("b", 2),
+ ("a", 2),
+ ("d", 3),
+ ("a", 1),
+ ("a", 3),
+ ("d", 4),
+ ("c", 3),
+ ]
md = self.storage_class(mapping)
assert list(zip(md.keys(), md.listvalues())) == list(md.lists())
assert list(zip(md, iterlistvalues(md))) == list(iterlists(md))
- assert list(zip(iterkeys(md), iterlistvalues(md))) == \
- list(iterlists(md))
+ assert list(zip(iterkeys(md), iterlistvalues(md))) == list(iterlists(md))
- @pytest.mark.skipif(not PY2, reason='viewmethods work only for the 2-nd version.')
+ @pytest.mark.skipif(not PY2, reason="viewmethods work only for the 2-nd version.")
def test_view_methods(self):
- mapping = [('a', 'b'), ('a', 'c')]
+ mapping = [("a", "b"), ("a", "c")]
md = self.storage_class(mapping)
- vi = md.viewitems()
- vk = md.viewkeys()
- vv = md.viewvalues()
+ vi = md.viewitems() # noqa: B302
+ vk = md.viewkeys() # noqa: B302
+ vv = md.viewvalues() # noqa: B302
assert list(vi) == list(md.items())
assert list(vk) == list(md.keys())
assert list(vv) == list(md.values())
- md['k'] = 'n'
+ md["k"] = "n"
assert list(vi) == list(md.items())
assert list(vk) == list(md.keys())
assert list(vv) == list(md.values())
- @pytest.mark.skipif(not PY2, reason='viewmethods work only for the 2-nd version.')
+ @pytest.mark.skipif(not PY2, reason="viewmethods work only for the 2-nd version.")
def test_viewitems_with_multi(self):
- mapping = [('a', 'b'), ('a', 'c')]
+ mapping = [("a", "b"), ("a", "c")]
md = self.storage_class(mapping)
- vi = md.viewitems(multi=True)
+ vi = md.viewitems(multi=True) # noqa: B302
assert list(vi) == list(md.items(multi=True))
- md['k'] = 'n'
+ md["k"] = "n"
assert list(vi) == list(md.items(multi=True))
def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self):
- mapping = [('a', 'b'), ('a', 'c')]
+ mapping = [("a", "b"), ("a", "c")]
md = self.storage_class(mapping)
- md.setlistdefault('empty', [])
+ md.setlistdefault("empty", [])
with pytest.raises(KeyError):
- md['empty']
+ md["empty"]
class TestOrderedMultiDict(_MutableMultiDictTests):
@@ -447,90 +475,93 @@ class TestOrderedMultiDict(_MutableMultiDictTests):
d = cls()
assert not d
- d.add('foo', 'bar')
+ d.add("foo", "bar")
assert len(d) == 1
- d.add('foo', 'baz')
+ d.add("foo", "baz")
assert len(d) == 1
- assert list(iteritems(d)) == [('foo', 'bar')]
- assert list(d) == ['foo']
- assert list(iteritems(d, multi=True)) == \
- [('foo', 'bar'), ('foo', 'baz')]
- del d['foo']
+ assert list(iteritems(d)) == [("foo", "bar")]
+ assert list(d) == ["foo"]
+ assert list(iteritems(d, multi=True)) == [("foo", "bar"), ("foo", "baz")]
+ del d["foo"]
assert not d
assert len(d) == 0
assert list(d) == []
- d.update([('foo', 1), ('foo', 2), ('bar', 42)])
- d.add('foo', 3)
- assert d.getlist('foo') == [1, 2, 3]
- assert d.getlist('bar') == [42]
- assert list(iteritems(d)) == [('foo', 1), ('bar', 42)]
+ d.update([("foo", 1), ("foo", 2), ("bar", 42)])
+ d.add("foo", 3)
+ assert d.getlist("foo") == [1, 2, 3]
+ assert d.getlist("bar") == [42]
+ assert list(iteritems(d)) == [("foo", 1), ("bar", 42)]
- expected = ['foo', 'bar']
+ expected = ["foo", "bar"]
assert list(d.keys()) == expected
assert list(d) == expected
assert list(iterkeys(d)) == expected
- assert list(iteritems(d, multi=True)) == \
- [('foo', 1), ('foo', 2), ('bar', 42), ('foo', 3)]
+ assert list(iteritems(d, multi=True)) == [
+ ("foo", 1),
+ ("foo", 2),
+ ("bar", 42),
+ ("foo", 3),
+ ]
assert len(d) == 2
- assert d.pop('foo') == 1
- assert d.pop('blafasel', None) is None
- assert d.pop('blafasel', 42) == 42
+ assert d.pop("foo") == 1
+ assert d.pop("blafasel", None) is None
+ assert d.pop("blafasel", 42) == 42
assert len(d) == 1
- assert d.poplist('bar') == [42]
+ assert d.poplist("bar") == [42]
assert not d
- d.get('missingkey') is None
+ d.get("missingkey") is None
- d.add('foo', 42)
- d.add('foo', 23)
- d.add('bar', 2)
- d.add('foo', 42)
+ d.add("foo", 42)
+ d.add("foo", 23)
+ d.add("bar", 2)
+ d.add("foo", 42)
assert d == datastructures.MultiDict(d)
id = self.storage_class(d)
assert d == id
- d.add('foo', 2)
+ d.add("foo", 2)
assert d != id
- d.update({'blah': [1, 2, 3]})
- assert d['blah'] == 1
- assert d.getlist('blah') == [1, 2, 3]
+ d.update({"blah": [1, 2, 3]})
+ assert d["blah"] == 1
+ assert d.getlist("blah") == [1, 2, 3]
# setlist works
d = self.storage_class()
- d['foo'] = 42
- d.setlist('foo', [1, 2])
- assert d.getlist('foo') == [1, 2]
+ d["foo"] = 42
+ d.setlist("foo", [1, 2])
+ assert d.getlist("foo") == [1, 2]
with pytest.raises(BadRequestKeyError):
- d.pop('missing')
+ d.pop("missing")
with pytest.raises(BadRequestKeyError):
- d['missing']
+ d["missing"]
# popping
d = self.storage_class()
- d.add('foo', 23)
- d.add('foo', 42)
- d.add('foo', 1)
- assert d.popitem() == ('foo', 23)
+ d.add("foo", 23)
+ d.add("foo", 42)
+ d.add("foo", 1)
+ assert d.popitem() == ("foo", 23)
with pytest.raises(BadRequestKeyError):
d.popitem()
assert not d
- d.add('foo', 23)
- d.add('foo', 42)
- d.add('foo', 1)
- assert d.popitemlist() == ('foo', [23, 42, 1])
+ d.add("foo", 23)
+ d.add("foo", 42)
+ d.add("foo", 1)
+ assert d.popitemlist() == ("foo", [23, 42, 1])
with pytest.raises(BadRequestKeyError):
d.popitemlist()
# Unhashable
d = self.storage_class()
- d.add('foo', 23)
+ d.add("foo", 23)
pytest.raises(TypeError, hash, d)
def test_iterables(self):
@@ -538,92 +569,91 @@ class TestOrderedMultiDict(_MutableMultiDictTests):
b = datastructures.MultiDict((("key_b", "value_b"),))
ab = datastructures.CombinedMultiDict((a, b))
- assert sorted(ab.lists()) == [('key_a', ['value_a']), ('key_b', ['value_b'])]
- assert sorted(ab.listvalues()) == [['value_a'], ['value_b']]
+ assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])]
+ assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]]
assert sorted(ab.keys()) == ["key_a", "key_b"]
- assert sorted(iterlists(ab)) == [('key_a', ['value_a']), ('key_b', ['value_b'])]
- assert sorted(iterlistvalues(ab)) == [['value_a'], ['value_b']]
+ assert sorted(iterlists(ab)) == [("key_a", ["value_a"]), ("key_b", ["value_b"])]
+ assert sorted(iterlistvalues(ab)) == [["value_a"], ["value_b"]]
assert sorted(iterkeys(ab)) == ["key_a", "key_b"]
def test_get_description(self):
data = datastructures.OrderedMultiDict()
with pytest.raises(BadRequestKeyError) as exc_info:
- data['baz']
+ data["baz"]
- assert 'baz' in exc_info.value.get_description()
+ assert "baz" in exc_info.value.get_description()
with pytest.raises(BadRequestKeyError) as exc_info:
- data.pop('baz')
+ data.pop("baz")
- assert 'baz' in exc_info.value.get_description()
+ assert "baz" in exc_info.value.get_description()
exc_info.value.args = ()
- assert 'baz' not in exc_info.value.get_description()
+ assert "baz" not in exc_info.value.get_description()
class TestTypeConversionDict(object):
storage_class = datastructures.TypeConversionDict
def test_value_conversion(self):
- d = self.storage_class(foo='1')
- assert d.get('foo', type=int) == 1
+ d = self.storage_class(foo="1")
+ assert d.get("foo", type=int) == 1
def test_return_default_when_conversion_is_not_possible(self):
- d = self.storage_class(foo='bar')
- assert d.get('foo', default=-1, type=int) == -1
+ d = self.storage_class(foo="bar")
+ assert d.get("foo", default=-1, type=int) == -1
def test_propagate_exceptions_in_conversion(self):
- d = self.storage_class(foo='bar')
- switch = {'a': 1}
+ d = self.storage_class(foo="bar")
+ switch = {"a": 1}
with pytest.raises(KeyError):
- d.get('foo', type=lambda x: switch[x])
+ d.get("foo", type=lambda x: switch[x])
class TestCombinedMultiDict(object):
storage_class = datastructures.CombinedMultiDict
def test_basic_interface(self):
- d1 = datastructures.MultiDict([('foo', '1')])
- d2 = datastructures.MultiDict([('bar', '2'), ('bar', '3')])
+ d1 = datastructures.MultiDict([("foo", "1")])
+ d2 = datastructures.MultiDict([("bar", "2"), ("bar", "3")])
d = self.storage_class([d1, d2])
# lookup
- assert d['foo'] == '1'
- assert d['bar'] == '2'
- assert d.getlist('bar') == ['2', '3']
+ assert d["foo"] == "1"
+ assert d["bar"] == "2"
+ assert d.getlist("bar") == ["2", "3"]
- assert sorted(d.items()) == [('bar', '2'), ('foo', '1')]
- assert sorted(d.items(multi=True)) == \
- [('bar', '2'), ('bar', '3'), ('foo', '1')]
- assert 'missingkey' not in d
- assert 'foo' in d
+ assert sorted(d.items()) == [("bar", "2"), ("foo", "1")]
+ assert sorted(d.items(multi=True)) == [("bar", "2"), ("bar", "3"), ("foo", "1")]
+ assert "missingkey" not in d
+ assert "foo" in d
# type lookup
- assert d.get('foo', type=int) == 1
- assert d.getlist('bar', type=int) == [2, 3]
+ assert d.get("foo", type=int) == 1
+ assert d.getlist("bar", type=int) == [2, 3]
# get key errors for missing stuff
with pytest.raises(KeyError):
- d['missing']
+ d["missing"]
# make sure that they are immutable
with pytest.raises(TypeError):
- d['foo'] = 'blub'
+ d["foo"] = "blub"
# copies are mutable
d = d.copy()
- d['foo'] = 'blub'
+ d["foo"] = "blub"
# make sure lists merges
md1 = datastructures.MultiDict((("foo", "bar"),))
md2 = datastructures.MultiDict((("foo", "blafasel"),))
x = self.storage_class((md1, md2))
- assert list(iterlists(x)) == [('foo', ['bar', 'blafasel'])]
+ assert list(iterlists(x)) == [("foo", ["bar", "blafasel"])]
def test_length(self):
- d1 = datastructures.MultiDict([('foo', '1')])
- d2 = datastructures.MultiDict([('bar', '2')])
+ d1 = datastructures.MultiDict([("foo", "1")])
+ d2 = datastructures.MultiDict([("bar", "2")])
assert len(d1) == len(d2) == 1
d = self.storage_class([d1, d2])
assert len(d) == 2
@@ -637,143 +667,135 @@ class TestHeaders(object):
def test_basic_interface(self):
headers = self.storage_class()
- headers.add('Content-Type', 'text/plain')
- headers.add('X-Foo', 'bar')
- assert 'x-Foo' in headers
- assert 'Content-type' in headers
+ headers.add("Content-Type", "text/plain")
+ headers.add("X-Foo", "bar")
+ assert "x-Foo" in headers
+ assert "Content-type" in headers
- headers['Content-Type'] = 'foo/bar'
- assert headers['Content-Type'] == 'foo/bar'
- assert len(headers.getlist('Content-Type')) == 1
+ headers["Content-Type"] = "foo/bar"
+ assert headers["Content-Type"] == "foo/bar"
+ assert len(headers.getlist("Content-Type")) == 1
# list conversion
- assert headers.to_wsgi_list() == [
- ('Content-Type', 'foo/bar'),
- ('X-Foo', 'bar')
- ]
- assert str(headers) == (
- "Content-Type: foo/bar\r\n"
- "X-Foo: bar\r\n"
- "\r\n"
- )
+ assert headers.to_wsgi_list() == [("Content-Type", "foo/bar"), ("X-Foo", "bar")]
+ assert str(headers) == "Content-Type: foo/bar\r\nX-Foo: bar\r\n\r\n"
assert str(self.storage_class()) == "\r\n"
# extended add
- headers.add('Content-Disposition', 'attachment', filename='foo')
- assert headers['Content-Disposition'] == 'attachment; filename=foo'
+ headers.add("Content-Disposition", "attachment", filename="foo")
+ assert headers["Content-Disposition"] == "attachment; filename=foo"
- headers.add('x', 'y', z='"')
- assert headers['x'] == r'y; z="\""'
+ headers.add("x", "y", z='"')
+ assert headers["x"] == r'y; z="\""'
def test_defaults_and_conversion(self):
# defaults
- headers = self.storage_class([
- ('Content-Type', 'text/plain'),
- ('X-Foo', 'bar'),
- ('X-Bar', '1'),
- ('X-Bar', '2')
- ])
- assert headers.getlist('x-bar') == ['1', '2']
- assert headers.get('x-Bar') == '1'
- assert headers.get('Content-Type') == 'text/plain'
-
- assert headers.setdefault('X-Foo', 'nope') == 'bar'
- assert headers.setdefault('X-Bar', 'nope') == '1'
- assert headers.setdefault('X-Baz', 'quux') == 'quux'
- assert headers.setdefault('X-Baz', 'nope') == 'quux'
- headers.pop('X-Baz')
+ headers = self.storage_class(
+ [
+ ("Content-Type", "text/plain"),
+ ("X-Foo", "bar"),
+ ("X-Bar", "1"),
+ ("X-Bar", "2"),
+ ]
+ )
+ assert headers.getlist("x-bar") == ["1", "2"]
+ assert headers.get("x-Bar") == "1"
+ assert headers.get("Content-Type") == "text/plain"
+
+ assert headers.setdefault("X-Foo", "nope") == "bar"
+ assert headers.setdefault("X-Bar", "nope") == "1"
+ assert headers.setdefault("X-Baz", "quux") == "quux"
+ assert headers.setdefault("X-Baz", "nope") == "quux"
+ headers.pop("X-Baz")
# type conversion
- assert headers.get('x-bar', type=int) == 1
- assert headers.getlist('x-bar', type=int) == [1, 2]
+ assert headers.get("x-bar", type=int) == 1
+ assert headers.getlist("x-bar", type=int) == [1, 2]
# list like operations
- assert headers[0] == ('Content-Type', 'text/plain')
- assert headers[:1] == self.storage_class([('Content-Type', 'text/plain')])
+ assert headers[0] == ("Content-Type", "text/plain")
+ assert headers[:1] == self.storage_class([("Content-Type", "text/plain")])
del headers[:2]
del headers[-1]
- assert headers == self.storage_class([('X-Bar', '1')])
+ assert headers == self.storage_class([("X-Bar", "1")])
def test_copying(self):
- a = self.storage_class([('foo', 'bar')])
+ a = self.storage_class([("foo", "bar")])
b = a.copy()
- a.add('foo', 'baz')
- assert a.getlist('foo') == ['bar', 'baz']
- assert b.getlist('foo') == ['bar']
+ a.add("foo", "baz")
+ assert a.getlist("foo") == ["bar", "baz"]
+ assert b.getlist("foo") == ["bar"]
def test_popping(self):
- headers = self.storage_class([('a', 1)])
- assert headers.pop('a') == 1
- assert headers.pop('b', 2) == 2
+ headers = self.storage_class([("a", 1)])
+ assert headers.pop("a") == 1
+ assert headers.pop("b", 2) == 2
with pytest.raises(KeyError):
- headers.pop('c')
+ headers.pop("c")
def test_set_arguments(self):
a = self.storage_class()
- a.set('Content-Disposition', 'useless')
- a.set('Content-Disposition', 'attachment', filename='foo')
- assert a['Content-Disposition'] == 'attachment; filename=foo'
+ a.set("Content-Disposition", "useless")
+ a.set("Content-Disposition", "attachment", filename="foo")
+ assert a["Content-Disposition"] == "attachment; filename=foo"
def test_reject_newlines(self):
h = self.storage_class()
- for variation in 'foo\nbar', 'foo\r\nbar', 'foo\rbar':
+ for variation in "foo\nbar", "foo\r\nbar", "foo\rbar":
with pytest.raises(ValueError):
- h['foo'] = variation
+ h["foo"] = variation
with pytest.raises(ValueError):
- h.add('foo', variation)
+ h.add("foo", variation)
with pytest.raises(ValueError):
- h.add('foo', 'test', option=variation)
+ h.add("foo", "test", option=variation)
with pytest.raises(ValueError):
- h.set('foo', variation)
+ h.set("foo", variation)
with pytest.raises(ValueError):
- h.set('foo', 'test', option=variation)
+ h.set("foo", "test", option=variation)
def test_slicing(self):
# there's nothing wrong with these being native strings
# Headers doesn't care about the data types
h = self.storage_class()
- h.set('X-Foo-Poo', 'bleh')
- h.set('Content-Type', 'application/whocares')
- h.set('X-Forwarded-For', '192.168.0.123')
- h[:] = [(k, v) for k, v in h if k.startswith(u'X-')]
- assert list(h) == [
- ('X-Foo-Poo', 'bleh'),
- ('X-Forwarded-For', '192.168.0.123')
- ]
+ h.set("X-Foo-Poo", "bleh")
+ h.set("Content-Type", "application/whocares")
+ h.set("X-Forwarded-For", "192.168.0.123")
+ h[:] = [(k, v) for k, v in h if k.startswith(u"X-")]
+ assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")]
def test_bytes_operations(self):
h = self.storage_class()
- h.set('X-Foo-Poo', 'bleh')
- h.set('X-Whoops', b'\xff')
- h.set(b'X-Bytes', b'something')
+ h.set("X-Foo-Poo", "bleh")
+ h.set("X-Whoops", b"\xff")
+ h.set(b"X-Bytes", b"something")
- assert h.get('x-foo-poo', as_bytes=True) == b'bleh'
- assert h.get('x-whoops', as_bytes=True) == b'\xff'
- assert h.get('x-bytes') == 'something'
+ assert h.get("x-foo-poo", as_bytes=True) == b"bleh"
+ assert h.get("x-whoops", as_bytes=True) == b"\xff"
+ assert h.get("x-bytes") == "something"
def test_to_wsgi_list(self):
h = self.storage_class()
- h.set(u'Key', u'Value')
+ h.set(u"Key", u"Value")
for key, value in h.to_wsgi_list():
if PY2:
- strict_eq(key, b'Key')
- strict_eq(value, b'Value')
+ strict_eq(key, b"Key")
+ strict_eq(value, b"Value")
else:
- strict_eq(key, u'Key')
- strict_eq(value, u'Value')
+ strict_eq(key, u"Key")
+ strict_eq(value, u"Value")
def test_to_wsgi_list_bytes(self):
h = self.storage_class()
- h.set(b'Key', b'Value')
+ h.set(b"Key", b"Value")
for key, value in h.to_wsgi_list():
if PY2:
- strict_eq(key, b'Key')
- strict_eq(value, b'Value')
+ strict_eq(key, b"Key")
+ strict_eq(value, b"Value")
else:
- strict_eq(key, u'Key')
- strict_eq(value, u'Value')
+ strict_eq(key, u"Key")
+ strict_eq(value, u"Value")
class TestEnvironHeaders(object):
@@ -783,64 +805,53 @@ class TestEnvironHeaders(object):
# this happens in multiple WSGI servers because they
# use a vary naive way to convert the headers;
broken_env = {
- 'HTTP_CONTENT_TYPE': 'text/html',
- 'CONTENT_TYPE': 'text/html',
- 'HTTP_CONTENT_LENGTH': '0',
- 'CONTENT_LENGTH': '0',
- 'HTTP_ACCEPT': '*',
- 'wsgi.version': (1, 0)
+ "HTTP_CONTENT_TYPE": "text/html",
+ "CONTENT_TYPE": "text/html",
+ "HTTP_CONTENT_LENGTH": "0",
+ "CONTENT_LENGTH": "0",
+ "HTTP_ACCEPT": "*",
+ "wsgi.version": (1, 0),
}
headers = self.storage_class(broken_env)
assert headers
assert len(headers) == 3
assert sorted(headers) == [
- ('Accept', '*'),
- ('Content-Length', '0'),
- ('Content-Type', 'text/html')
+ ("Accept", "*"),
+ ("Content-Length", "0"),
+ ("Content-Type", "text/html"),
]
- assert not self.storage_class({'wsgi.version': (1, 0)})
- assert len(self.storage_class({'wsgi.version': (1, 0)})) == 0
+ assert not self.storage_class({"wsgi.version": (1, 0)})
+ assert len(self.storage_class({"wsgi.version": (1, 0)})) == 0
assert 42 not in headers
def test_skip_empty_special_vars(self):
- env = {
- 'HTTP_X_FOO': '42',
- 'CONTENT_TYPE': '',
- 'CONTENT_LENGTH': '',
- }
+ env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": ""}
headers = self.storage_class(env)
- assert dict(headers) == {'X-Foo': '42'}
+ assert dict(headers) == {"X-Foo": "42"}
- env = {
- 'HTTP_X_FOO': '42',
- 'CONTENT_TYPE': '',
- 'CONTENT_LENGTH': '0',
- }
+ env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": "0"}
headers = self.storage_class(env)
- assert dict(headers) == {'X-Foo': '42', 'Content-Length': '0'}
+ assert dict(headers) == {"X-Foo": "42", "Content-Length": "0"}
def test_return_type_is_unicode(self):
# environ contains native strings; we return unicode
- headers = self.storage_class({
- 'HTTP_FOO': '\xe2\x9c\x93',
- 'CONTENT_TYPE': 'text/plain',
- })
- assert headers['Foo'] == u"\xe2\x9c\x93"
- assert isinstance(headers['Foo'], text_type)
- assert isinstance(headers['Content-Type'], text_type)
+ headers = self.storage_class(
+ {"HTTP_FOO": "\xe2\x9c\x93", "CONTENT_TYPE": "text/plain"}
+ )
+ assert headers["Foo"] == u"\xe2\x9c\x93"
+ assert isinstance(headers["Foo"], text_type)
+ assert isinstance(headers["Content-Type"], text_type)
iter_output = dict(iter(headers))
- assert iter_output['Foo'] == u"\xe2\x9c\x93"
- assert isinstance(iter_output['Foo'], text_type)
- assert isinstance(iter_output['Content-Type'], text_type)
+ assert iter_output["Foo"] == u"\xe2\x9c\x93"
+ assert isinstance(iter_output["Foo"], text_type)
+ assert isinstance(iter_output["Content-Type"], text_type)
def test_bytes_operations(self):
- foo_val = '\xff'
- h = self.storage_class({
- 'HTTP_X_FOO': foo_val
- })
+ foo_val = "\xff"
+ h = self.storage_class({"HTTP_X_FOO": foo_val})
- assert h.get('x-foo', as_bytes=True) == b'\xff'
- assert h.get('x-foo') == u'\xff'
+ assert h.get("x-foo", as_bytes=True) == b"\xff"
+ assert h.get("x-foo") == u"\xff"
class TestHeaderSet(object):
@@ -848,21 +859,21 @@ class TestHeaderSet(object):
def test_basic_interface(self):
hs = self.storage_class()
- hs.add('foo')
- hs.add('bar')
- assert 'Bar' in hs
- assert hs.find('foo') == 0
- assert hs.find('BAR') == 1
- assert hs.find('baz') < 0
- hs.discard('missing')
- hs.discard('foo')
- assert hs.find('foo') < 0
- assert hs.find('bar') == 0
+ hs.add("foo")
+ hs.add("bar")
+ assert "Bar" in hs
+ assert hs.find("foo") == 0
+ assert hs.find("BAR") == 1
+ assert hs.find("baz") < 0
+ hs.discard("missing")
+ hs.discard("foo")
+ assert hs.find("foo") < 0
+ assert hs.find("bar") == 0
with pytest.raises(IndexError):
- hs.index('missing')
+ hs.index("missing")
- assert hs.index('bar') == 0
+ assert hs.index("bar") == 0
assert hs
hs.clear()
assert not hs
@@ -910,47 +921,44 @@ class TestCallbackDict(object):
def test_callback_dict_reads(self):
assert_calls, func = make_call_asserter()
- initial = {'a': 'foo', 'b': 'bar'}
+ initial = {"a": "foo", "b": "bar"}
dct = self.storage_class(initial=initial, on_update=func)
- with assert_calls(0, 'callback triggered by read-only method'):
+ with assert_calls(0, "callback triggered by read-only method"):
# read-only methods
- dct['a']
- dct.get('a')
- pytest.raises(KeyError, lambda: dct['x'])
- 'a' in dct
+ dct["a"]
+ dct.get("a")
+ pytest.raises(KeyError, lambda: dct["x"])
+ "a" in dct
list(iter(dct))
dct.copy()
- with assert_calls(0, 'callback triggered without modification'):
+ with assert_calls(0, "callback triggered without modification"):
# methods that may write but don't
- dct.pop('z', None)
- dct.setdefault('a')
+ dct.pop("z", None)
+ dct.setdefault("a")
def test_callback_dict_writes(self):
assert_calls, func = make_call_asserter()
- initial = {'a': 'foo', 'b': 'bar'}
+ initial = {"a": "foo", "b": "bar"}
dct = self.storage_class(initial=initial, on_update=func)
- with assert_calls(8, 'callback not triggered by write method'):
+ with assert_calls(8, "callback not triggered by write method"):
# always-write methods
- dct['z'] = 123
- dct['z'] = 123 # must trigger again
- del dct['z']
- dct.pop('b', None)
- dct.setdefault('x')
+ dct["z"] = 123
+ dct["z"] = 123 # must trigger again
+ del dct["z"]
+ dct.pop("b", None)
+ dct.setdefault("x")
dct.popitem()
dct.update([])
dct.clear()
- with assert_calls(0, 'callback triggered by failed del'):
- pytest.raises(KeyError, lambda: dct.__delitem__('x'))
- with assert_calls(0, 'callback triggered by failed pop'):
- pytest.raises(KeyError, lambda: dct.pop('x'))
+ with assert_calls(0, "callback triggered by failed del"):
+ pytest.raises(KeyError, lambda: dct.__delitem__("x"))
+ with assert_calls(0, "callback triggered by failed pop"):
+ pytest.raises(KeyError, lambda: dct.pop("x"))
class TestCacheControl(object):
-
def test_repr(self):
- cc = datastructures.RequestCacheControl(
- [("max-age", "0"), ("private", "True")],
- )
+ cc = datastructures.RequestCacheControl([("max-age", "0"), ("private", "True")])
assert repr(cc) == "<RequestCacheControl max-age='0' private='True'>"
@@ -958,117 +966,118 @@ class TestAccept(object):
storage_class = datastructures.Accept
def test_accept_basic(self):
- accept = self.storage_class([('tinker', 0), ('tailor', 0.333),
- ('soldier', 0.667), ('sailor', 1)])
+ accept = self.storage_class(
+ [("tinker", 0), ("tailor", 0.333), ("soldier", 0.667), ("sailor", 1)]
+ )
# check __getitem__ on indices
- assert accept[3] == ('tinker', 0)
- assert accept[2] == ('tailor', 0.333)
- assert accept[1] == ('soldier', 0.667)
- assert accept[0], ('sailor', 1)
+ assert accept[3] == ("tinker", 0)
+ assert accept[2] == ("tailor", 0.333)
+ assert accept[1] == ("soldier", 0.667)
+ assert accept[0], ("sailor", 1)
# check __getitem__ on string
- assert accept['tinker'] == 0
- assert accept['tailor'] == 0.333
- assert accept['soldier'] == 0.667
- assert accept['sailor'] == 1
- assert accept['spy'] == 0
+ assert accept["tinker"] == 0
+ assert accept["tailor"] == 0.333
+ assert accept["soldier"] == 0.667
+ assert accept["sailor"] == 1
+ assert accept["spy"] == 0
# check quality method
- assert accept.quality('tinker') == 0
- assert accept.quality('tailor') == 0.333
- assert accept.quality('soldier') == 0.667
- assert accept.quality('sailor') == 1
- assert accept.quality('spy') == 0
+ assert accept.quality("tinker") == 0
+ assert accept.quality("tailor") == 0.333
+ assert accept.quality("soldier") == 0.667
+ assert accept.quality("sailor") == 1
+ assert accept.quality("spy") == 0
# check __contains__
- assert 'sailor' in accept
- assert 'spy' not in accept
+ assert "sailor" in accept
+ assert "spy" not in accept
# check index method
- assert accept.index('tinker') == 3
- assert accept.index('tailor') == 2
- assert accept.index('soldier') == 1
- assert accept.index('sailor') == 0
+ assert accept.index("tinker") == 3
+ assert accept.index("tailor") == 2
+ assert accept.index("soldier") == 1
+ assert accept.index("sailor") == 0
with pytest.raises(ValueError):
- accept.index('spy')
+ accept.index("spy")
# check find method
- assert accept.find('tinker') == 3
- assert accept.find('tailor') == 2
- assert accept.find('soldier') == 1
- assert accept.find('sailor') == 0
- assert accept.find('spy') == -1
+ assert accept.find("tinker") == 3
+ assert accept.find("tailor") == 2
+ assert accept.find("soldier") == 1
+ assert accept.find("sailor") == 0
+ assert accept.find("spy") == -1
# check to_header method
- assert accept.to_header() == \
- 'sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0'
+ assert accept.to_header() == "sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0"
# check best_match method
- assert accept.best_match(['tinker', 'tailor', 'soldier', 'sailor'],
- default=None) == 'sailor'
- assert accept.best_match(['tinker', 'tailor', 'soldier'],
- default=None) == 'soldier'
- assert accept.best_match(['tinker', 'tailor'], default=None) == \
- 'tailor'
- assert accept.best_match(['tinker'], default=None) is None
- assert accept.best_match(['tinker'], default='x') == 'x'
+ assert (
+ accept.best_match(["tinker", "tailor", "soldier", "sailor"], default=None)
+ == "sailor"
+ )
+ assert (
+ accept.best_match(["tinker", "tailor", "soldier"], default=None)
+ == "soldier"
+ )
+ assert accept.best_match(["tinker", "tailor"], default=None) == "tailor"
+ assert accept.best_match(["tinker"], default=None) is None
+ assert accept.best_match(["tinker"], default="x") == "x"
def test_accept_wildcard(self):
- accept = self.storage_class([('*', 0), ('asterisk', 1)])
- assert '*' in accept
- assert accept.best_match(['asterisk', 'star'], default=None) == \
- 'asterisk'
- assert accept.best_match(['star'], default=None) is None
+ accept = self.storage_class([("*", 0), ("asterisk", 1)])
+ assert "*" in accept
+ assert accept.best_match(["asterisk", "star"], default=None) == "asterisk"
+ assert accept.best_match(["star"], default=None) is None
def test_accept_keep_order(self):
- accept = self.storage_class([('*', 1)])
+ accept = self.storage_class([("*", 1)])
assert accept.best_match(["alice", "bob"]) == "alice"
assert accept.best_match(["bob", "alice"]) == "bob"
- accept = self.storage_class([('alice', 1), ('bob', 1)])
+ accept = self.storage_class([("alice", 1), ("bob", 1)])
assert accept.best_match(["alice", "bob"]) == "alice"
assert accept.best_match(["bob", "alice"]) == "bob"
def test_accept_wildcard_specificity(self):
- accept = self.storage_class([('asterisk', 0), ('star', 0.5), ('*', 1)])
- assert accept.best_match(['star', 'asterisk'], default=None) == 'star'
- assert accept.best_match(['asterisk', 'star'], default=None) == 'star'
- assert accept.best_match(['asterisk', 'times'], default=None) == \
- 'times'
- assert accept.best_match(['asterisk'], default=None) is None
+ accept = self.storage_class([("asterisk", 0), ("star", 0.5), ("*", 1)])
+ assert accept.best_match(["star", "asterisk"], default=None) == "star"
+ assert accept.best_match(["asterisk", "star"], default=None) == "star"
+ assert accept.best_match(["asterisk", "times"], default=None) == "times"
+ assert accept.best_match(["asterisk"], default=None) is None
class TestMIMEAccept(object):
storage_class = datastructures.MIMEAccept
def test_accept_wildcard_subtype(self):
- accept = self.storage_class([('text/*', 1)])
- assert accept.best_match(['text/html'], default=None) == 'text/html'
- assert accept.best_match(['image/png', 'text/plain']) == 'text/plain'
- assert accept.best_match(['image/png'], default=None) is None
+ accept = self.storage_class([("text/*", 1)])
+ assert accept.best_match(["text/html"], default=None) == "text/html"
+ assert accept.best_match(["image/png", "text/plain"]) == "text/plain"
+ assert accept.best_match(["image/png"], default=None) is None
def test_accept_wildcard_specificity(self):
- accept = self.storage_class([('*/*', 1), ('text/html', 1)])
- assert accept.best_match(['image/png', 'text/html']) == 'text/html'
- assert accept.best_match(['image/png', 'text/plain']) == 'image/png'
- accept = self.storage_class([('*/*', 1), ('text/html', 1),
- ('image/*', 1)])
- assert accept.best_match(['image/png', 'text/html']) == 'text/html'
- assert accept.best_match(['text/plain', 'image/png']) == 'image/png'
+ accept = self.storage_class([("*/*", 1), ("text/html", 1)])
+ assert accept.best_match(["image/png", "text/html"]) == "text/html"
+ assert accept.best_match(["image/png", "text/plain"]) == "image/png"
+ accept = self.storage_class([("*/*", 1), ("text/html", 1), ("image/*", 1)])
+ assert accept.best_match(["image/png", "text/html"]) == "text/html"
+ assert accept.best_match(["text/plain", "image/png"]) == "image/png"
class TestFileStorage(object):
storage_class = datastructures.FileStorage
def test_mimetype_always_lowercase(self):
- file_storage = self.storage_class(content_type='APPLICATION/JSON')
- assert file_storage.mimetype == 'application/json'
+ file_storage = self.storage_class(content_type="APPLICATION/JSON")
+ assert file_storage.mimetype == "application/json"
def test_bytes_proper_sentinel(self):
# ensure we iterate over new lines and don't enter into an infinite loop
import io
+
unicode_storage = self.storage_class(io.StringIO(u"one\ntwo"))
- for idx, line in enumerate(unicode_storage):
+ for idx, _line in enumerate(unicode_storage):
assert idx < 2
assert idx == 1
binary_storage = self.storage_class(io.BytesIO(b"one\ntwo"))
- for idx, line in enumerate(binary_storage):
+ for idx, _line in enumerate(binary_storage):
assert idx < 2
assert idx == 1
- @pytest.mark.skipif(PY2, reason='io.IOBase is only needed in PY3.')
+ @pytest.mark.skipif(PY2, reason="io.IOBase is only needed in PY3.")
@pytest.mark.parametrize("stream", (tempfile.SpooledTemporaryFile, io.BytesIO))
def test_proxy_can_access_stream_attrs(self, stream):
"""``SpooledTemporaryFile`` doesn't implement some of
@@ -1084,9 +1093,7 @@ class TestFileStorage(object):
assert hasattr(file_storage, name)
-@pytest.mark.parametrize(
- "ranges", ([(0, 1), (-5, None)], [(5, None)])
-)
+@pytest.mark.parametrize("ranges", ([(0, 1), (-5, None)], [(5, None)]))
def test_range_to_header(ranges):
header = Range("byes", ranges).to_header()
r = http.parse_range_header(header)
diff --git a/tests/test_debug.py b/tests/test_debug.py
index 43d9ae6d..fa94b653 100644
--- a/tests/test_debug.py
+++ b/tests/test_debug.py
@@ -8,43 +8,59 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
-import re
import io
+import re
+import sys
import pytest
import requests
-from werkzeug.debug import get_machine_id, DebuggedApplication
-from werkzeug.debug.repr import debug_repr, DebugReprGenerator, \
- dump, helper
+from werkzeug._compat import PY2
+from werkzeug.debug import DebuggedApplication
+from werkzeug.debug import get_machine_id
from werkzeug.debug.console import HTMLStringO
+from werkzeug.debug.repr import debug_repr
+from werkzeug.debug.repr import DebugReprGenerator
+from werkzeug.debug.repr import dump
+from werkzeug.debug.repr import helper
from werkzeug.debug.tbtools import Traceback
from werkzeug.test import Client
-from werkzeug.wrappers import Request, Response
-from werkzeug._compat import PY2
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
class TestDebugRepr(object):
-
def test_basic_repr(self):
- assert debug_repr([]) == u'[]'
- assert debug_repr([1, 2]) == \
- u'[<span class="number">1</span>, <span class="number">2</span>]'
- assert debug_repr([1, 'test']) == \
- u'[<span class="number">1</span>, <span class="string">\'test\'</span>]'
- assert debug_repr([None]) == \
- u'[<span class="object">None</span>]'
+ assert debug_repr([]) == u"[]"
+ assert (
+ debug_repr([1, 2])
+ == u'[<span class="number">1</span>, <span class="number">2</span>]'
+ )
+ assert (
+ debug_repr([1, "test"])
+ == u'[<span class="number">1</span>, <span class="string">\'test\'</span>]'
+ )
+ assert debug_repr([None]) == u'[<span class="object">None</span>]'
def test_string_repr(self):
- assert debug_repr('') == u'<span class="string">\'\'</span>'
- assert debug_repr('foo') == u'<span class="string">\'foo\'</span>'
- assert debug_repr('s' * 80) == u'<span class="string">\''\
- + 's' * 69 + '<span class="extended">'\
- + 's' * 11 + '\'</span></span>'
- assert debug_repr('<' * 80) == u'<span class="string">\''\
- + '&lt;' * 69 + '<span class="extended">'\
- + '&lt;' * 11 + '\'</span></span>'
+ assert debug_repr("") == u"<span class=\"string\">''</span>"
+ assert debug_repr("foo") == u"<span class=\"string\">'foo'</span>"
+ assert (
+ debug_repr("s" * 80)
+ == u'<span class="string">\''
+ + "s" * 69
+ + '<span class="extended">'
+ + "s" * 11
+ + "'</span></span>"
+ )
+ assert (
+ debug_repr("<" * 80)
+ == u'<span class="string">\''
+ + "&lt;" * 69
+ + '<span class="extended">'
+ + "&lt;" * 11
+ + "'</span></span>"
+ )
def test_string_subclass_repr(self):
class Test(str):
@@ -52,16 +68,16 @@ class TestDebugRepr(object):
assert debug_repr(Test("foo")) == (
u'<span class="module">tests.test_debug.</span>'
- u'Test(<span class="string">\'foo\'</span>)'
+ u"Test(<span class=\"string\">'foo'</span>)"
)
@pytest.mark.skipif(not PY2, reason="u prefix on py2 only")
def test_unicode_repr(self):
- assert debug_repr(u"foo") == u'<span class="string">u\'foo\'</span>'
+ assert debug_repr(u"foo") == u"<span class=\"string\">u'foo'</span>"
@pytest.mark.skipif(PY2, reason="b prefix on py3 only")
def test_bytes_repr(self):
- assert debug_repr(b"foo") == u'<span class="string">b\'foo\'</span>'
+ assert debug_repr(b"foo") == u"<span class=\"string\">b'foo'</span>"
def test_sequence_repr(self):
assert debug_repr(list(range(20))) == (
@@ -79,51 +95,84 @@ class TestDebugRepr(object):
)
def test_mapping_repr(self):
- assert debug_repr({}) == u'{}'
- assert debug_repr({'foo': 42}) == (
+ assert debug_repr({}) == u"{}"
+ assert debug_repr({"foo": 42}) == (
u'{<span class="pair"><span class="key"><span class="string">\'foo\''
u'</span></span>: <span class="value"><span class="number">42'
- u'</span></span></span>}'
+ u"</span></span></span>}"
)
assert debug_repr(dict(zip(range(10), [None] * 10))) == (
- u'{<span class="pair"><span class="key"><span class="number">0</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">1</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">2</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">3</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="extended"><span class="pair"><span class="key"><span class="number">4</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">5</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">6</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">7</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">8</span></span>: <span class="value"><span class="object">None</span></span></span>, <span class="pair"><span class="key"><span class="number">9</span></span>: <span class="value"><span class="object">None</span></span></span></span>}' # noqa
+ u'{<span class="pair"><span class="key"><span class="number">0'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">1'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">2'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">3'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="extended">'
+ u'<span class="pair"><span class="key"><span class="number">4'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">5'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">6'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">7'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">8'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span>, "
+ u'<span class="pair"><span class="key"><span class="number">9'
+ u'</span></span>: <span class="value"><span class="object">None'
+ u"</span></span></span></span>}"
)
- assert debug_repr((1, 'zwei', u'drei')) == (
+ assert debug_repr((1, "zwei", u"drei")) == (
u'(<span class="number">1</span>, <span class="string">\''
- u'zwei\'</span>, <span class="string">%s\'drei\'</span>)'
- ) % ('u' if PY2 else '')
+ u"zwei'</span>, <span class=\"string\">%s'drei'</span>)"
+ ) % ("u" if PY2 else "")
def test_custom_repr(self):
class Foo(object):
-
def __repr__(self):
- return '<Foo 42>'
- assert debug_repr(Foo()) == \
- '<span class="object">&lt;Foo 42&gt;</span>'
+ return "<Foo 42>"
+
+ assert debug_repr(Foo()) == '<span class="object">&lt;Foo 42&gt;</span>'
def test_list_subclass_repr(self):
class MyList(list):
pass
+
assert debug_repr(MyList([1, 2])) == (
u'<span class="module">tests.test_debug.</span>MyList(['
u'<span class="number">1</span>, <span class="number">2</span>])'
)
def test_regex_repr(self):
- assert debug_repr(re.compile(r'foo\d')) == \
- u're.compile(<span class="string regex">r\'foo\\d\'</span>)'
+ assert (
+ debug_repr(re.compile(r"foo\d"))
+ == u"re.compile(<span class=\"string regex\">r'foo\\d'</span>)"
+ )
# No ur'' in Py3
# https://bugs.python.org/issue15096
- assert debug_repr(re.compile(u'foo\\d')) == (
- u're.compile(<span class="string regex">%sr\'foo\\d\'</span>)' %
- ('u' if PY2 else '')
+ assert debug_repr(re.compile(u"foo\\d")) == (
+ u"re.compile(<span class=\"string regex\">%sr'foo\\d'</span>)"
+ % ("u" if PY2 else "")
)
def test_set_repr(self):
- assert debug_repr(frozenset('x')) == \
- u'frozenset([<span class="string">\'x\'</span>])'
- assert debug_repr(set('x')) == \
- u'set([<span class="string">\'x\'</span>])'
+ assert (
+ debug_repr(frozenset("x"))
+ == u"frozenset([<span class=\"string\">'x'</span>])"
+ )
+ assert debug_repr(set("x")) == u"set([<span class=\"string\">'x'</span>])"
def test_recursive_repr(self):
a = [1]
@@ -132,13 +181,12 @@ class TestDebugRepr(object):
def test_broken_repr(self):
class Foo(object):
-
def __repr__(self):
- raise Exception('broken!')
+ raise Exception("broken!")
assert debug_repr(Foo()) == (
u'<span class="brokenrepr">&lt;broken repr (Exception: '
- u'broken!)&gt;</span>'
+ u"broken!)&gt;</span>"
)
@@ -151,25 +199,24 @@ class Foo(object):
class TestDebugHelpers(object):
-
def test_object_dumping(self):
drg = DebugReprGenerator()
out = drg.dump_object(Foo())
- assert re.search('Details for tests.test_debug.Foo object at', out)
+ assert re.search("Details for tests.test_debug.Foo object at", out)
assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL)
assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL)
assert re.search('<th>z.*<span class="number">15</span>', out, flags=re.DOTALL)
- out = drg.dump_object({'x': 42, 'y': 23})
- assert re.search('Contents of', out)
+ out = drg.dump_object({"x": 42, "y": 23})
+ assert re.search("Contents of", out)
assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL)
assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL)
- out = drg.dump_object({'x': 42, 'y': 23, 23: 11})
- assert not re.search('Contents of', out)
+ out = drg.dump_object({"x": 42, "y": 23, 23: 11})
+ assert not re.search("Contents of", out)
- out = drg.dump_locals({'x': 42, 'y': 23})
- assert re.search('Local variables in frame', out)
+ out = drg.dump_locals({"x": 42, "y": 23})
+ assert re.search("Local variables in frame", out)
assert re.search('<th>x.*<span class="number">42</span>', out, flags=re.DOTALL)
assert re.search('<th>y.*<span class="number">23</span>', out, flags=re.DOTALL)
@@ -184,11 +231,11 @@ class TestDebugHelpers(object):
finally:
sys.stdout = old
- assert 'Details for list object at' in x
+ assert "Details for list object at" in x
assert '<span class="number">1</span>' in x
- assert 'Local variables in frame' in y
- assert '<th>x' in y
- assert '<th>old' in y
+ assert "Local variables in frame" in y
+ assert "<th>x" in y
+ assert "<th>old" in y
def test_debug_help(self):
old = sys.stdout
@@ -199,8 +246,8 @@ class TestDebugHelpers(object):
finally:
sys.stdout = old
- assert 'Help on list object' in x
- assert '__delitem__' in x
+ assert "Help on list object" in x
+ assert "__delitem__" in x
@pytest.mark.skipif(PY2, reason="Python 2 doesn't have chained exceptions.")
def test_exc_divider_found_on_chained_exception(self):
@@ -208,6 +255,7 @@ class TestDebugHelpers(object):
def app(request):
def do_something():
raise ValueError("inner")
+
try:
do_something()
except ValueError:
@@ -223,7 +271,6 @@ class TestDebugHelpers(object):
class TestTraceback(object):
-
def test_log(self):
try:
1 / 0
@@ -235,12 +282,14 @@ class TestTraceback(object):
assert buffer_.getvalue().strip() == traceback.plaintext.strip()
def test_sourcelines_encoding(self):
- source = (u'# -*- coding: latin1 -*-\n\n'
- u'def foo():\n'
- u' """höhö"""\n'
- u' 1 / 0\n'
- u'foo()').encode('latin1')
- code = compile(source, filename='lol.py', mode='exec')
+ source = (
+ u"# -*- coding: latin1 -*-\n\n"
+ u"def foo():\n"
+ u' """höhö"""\n'
+ u" 1 / 0\n"
+ u"foo()"
+ ).encode("latin1")
+ code = compile(source, filename="lol.py", mode="exec")
try:
eval(code)
except ZeroDivisionError:
@@ -248,23 +297,23 @@ class TestTraceback(object):
frames = traceback.frames
assert len(frames) == 3
- assert frames[1].filename == 'lol.py'
- assert frames[2].filename == 'lol.py'
+ assert frames[1].filename == "lol.py"
+ assert frames[2].filename == "lol.py"
class Loader(object):
-
def get_source(self, module):
return source
frames[1].loader = frames[2].loader = Loader()
assert frames[1].sourcelines == frames[2].sourcelines
- assert [line.code for line in frames[1].get_annotated_lines()] == \
- [line.code for line in frames[2].get_annotated_lines()]
- assert u'höhö' in frames[1].sourcelines[3]
+ assert [line.code for line in frames[1].get_annotated_lines()] == [
+ line.code for line in frames[2].get_annotated_lines()
+ ]
+ assert u"höhö" in frames[1].sourcelines[3]
def test_filename_encoding(self, tmpdir, monkeypatch):
- moduledir = tmpdir.mkdir('föö')
- moduledir.join('bar.py').write('def foo():\n 1/0\n')
+ moduledir = tmpdir.mkdir("föö")
+ moduledir.join("bar.py").write("def foo():\n 1/0\n")
monkeypatch.syspath_prepend(str(moduledir))
import bar
@@ -274,7 +323,7 @@ class TestTraceback(object):
except ZeroDivisionError:
traceback = Traceback(*sys.exc_info())
- assert u'föö' in u'\n'.join(frame.render() for frame in traceback.frames)
+ assert u"föö" in u"\n".join(frame.render() for frame in traceback.frames)
def test_get_machine_id():
@@ -282,9 +331,10 @@ def test_get_machine_id():
assert isinstance(rv, bytes)
-@pytest.mark.parametrize('crash', (True, False))
+@pytest.mark.parametrize("crash", (True, False))
def test_basic(dev_server, crash):
- server = dev_server('''
+ server = dev_server(
+ """
from werkzeug.debug import DebuggedApplication
@DebuggedApplication
@@ -293,12 +343,14 @@ def test_basic(dev_server, crash):
1 / 0
start_response('200 OK', [('Content-Type', 'text/html')])
return [b'hello']
- '''.format(crash=crash))
+ """.format(
+ crash=crash
+ )
+ )
r = requests.get(server.url)
assert r.status_code == 500 if crash else 200
if crash:
- assert 'The debugger caught an exception in your WSGI application' \
- in r.text
+ assert "The debugger caught an exception in your WSGI application" in r.text
else:
- assert r.text == 'hello'
+ assert r.text == "hello"
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 7bb19a28..616b39c8 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -15,41 +15,44 @@
import pytest
from werkzeug import exceptions
+from werkzeug._compat import text_type
from werkzeug.datastructures import WWWAuthenticate
from werkzeug.wrappers import Response
-from werkzeug._compat import text_type
def test_proxy_exception():
- orig_resp = Response('Hello World')
+ orig_resp = Response("Hello World")
with pytest.raises(exceptions.HTTPException) as excinfo:
exceptions.abort(orig_resp)
resp = excinfo.value.get_response({})
assert resp is orig_resp
- assert resp.get_data() == b'Hello World'
-
-
-@pytest.mark.parametrize('test', [
- (exceptions.BadRequest, 400),
- (exceptions.Unauthorized, 401, 'Basic "test realm"'),
- (exceptions.Forbidden, 403),
- (exceptions.NotFound, 404),
- (exceptions.MethodNotAllowed, 405, ['GET', 'HEAD']),
- (exceptions.NotAcceptable, 406),
- (exceptions.RequestTimeout, 408),
- (exceptions.Gone, 410),
- (exceptions.LengthRequired, 411),
- (exceptions.PreconditionFailed, 412),
- (exceptions.RequestEntityTooLarge, 413),
- (exceptions.RequestURITooLarge, 414),
- (exceptions.UnsupportedMediaType, 415),
- (exceptions.UnprocessableEntity, 422),
- (exceptions.Locked, 423),
- (exceptions.InternalServerError, 500),
- (exceptions.NotImplemented, 501),
- (exceptions.BadGateway, 502),
- (exceptions.ServiceUnavailable, 503)
-])
+ assert resp.get_data() == b"Hello World"
+
+
+@pytest.mark.parametrize(
+ "test",
+ [
+ (exceptions.BadRequest, 400),
+ (exceptions.Unauthorized, 401, 'Basic "test realm"'),
+ (exceptions.Forbidden, 403),
+ (exceptions.NotFound, 404),
+ (exceptions.MethodNotAllowed, 405, ["GET", "HEAD"]),
+ (exceptions.NotAcceptable, 406),
+ (exceptions.RequestTimeout, 408),
+ (exceptions.Gone, 410),
+ (exceptions.LengthRequired, 411),
+ (exceptions.PreconditionFailed, 412),
+ (exceptions.RequestEntityTooLarge, 413),
+ (exceptions.RequestURITooLarge, 414),
+ (exceptions.UnsupportedMediaType, 415),
+ (exceptions.UnprocessableEntity, 422),
+ (exceptions.Locked, 423),
+ (exceptions.InternalServerError, 500),
+ (exceptions.NotImplemented, 501),
+ (exceptions.BadGateway, 502),
+ (exceptions.ServiceUnavailable, 503),
+ ],
+)
def test_aborter_general(test):
exc_type = test[0]
args = test[1:]
@@ -72,25 +75,26 @@ def test_aborter_custom():
def test_exception_repr():
exc = exceptions.NotFound()
assert text_type(exc) == (
- '404 Not Found: The requested URL was not found '
- 'on the server. If you entered the URL manually please check your '
- 'spelling and try again.')
+ "404 Not Found: The requested URL was not found on the server."
+ " If you entered the URL manually please check your spelling"
+ " and try again."
+ )
assert repr(exc) == "<NotFound '404: Not Found'>"
- exc = exceptions.NotFound('Not There')
- assert text_type(exc) == '404 Not Found: Not There'
+ exc = exceptions.NotFound("Not There")
+ assert text_type(exc) == "404 Not Found: Not There"
assert repr(exc) == "<NotFound '404: Not Found'>"
- exc = exceptions.HTTPException('An error message')
- assert text_type(exc) == '??? Unknown Error: An error message'
+ exc = exceptions.HTTPException("An error message")
+ assert text_type(exc) == "??? Unknown Error: An error message"
assert repr(exc) == "<HTTPException '???: Unknown Error'>"
def test_method_not_allowed_methods():
- exc = exceptions.MethodNotAllowed(['GET', 'HEAD', 'POST'])
+ exc = exceptions.MethodNotAllowed(["GET", "HEAD", "POST"])
h = dict(exc.get_headers({}))
- assert h['Allow'] == 'GET, HEAD, POST'
- assert 'The method is not allowed' in exc.get_description()
+ assert h["Allow"] == "GET, HEAD, POST"
+ assert "The method is not allowed" in exc.get_description()
def test_unauthorized_www_authenticate():
@@ -101,8 +105,8 @@ def test_unauthorized_www_authenticate():
exc = exceptions.Unauthorized(www_authenticate=basic)
h = dict(exc.get_headers({}))
- assert h['WWW-Authenticate'] == str(basic)
+ assert h["WWW-Authenticate"] == str(basic)
exc = exceptions.Unauthorized(www_authenticate=[digest, basic])
h = dict(exc.get_headers({}))
- assert h['WWW-Authenticate'] == ', '.join((str(digest), str(basic)))
+ assert h["WWW-Authenticate"] == ", ".join((str(digest), str(basic)))
diff --git a/tests/test_formparser.py b/tests/test_formparser.py
index 5f80ab5f..6d858386 100644
--- a/tests/test_formparser.py
+++ b/tests/test_formparser.py
@@ -10,106 +10,133 @@
"""
import csv
import io
-import pytest
-
-from os.path import join, dirname
+from os.path import dirname
+from os.path import join
-from tests import strict_eq
+import pytest
+from . import strict_eq
from werkzeug import formparser
-from werkzeug.test import create_environ, Client
-from werkzeug.wrappers import Request, Response
-from werkzeug.exceptions import RequestEntityTooLarge
+from werkzeug._compat import BytesIO
+from werkzeug._compat import PY2
from werkzeug.datastructures import MultiDict
-from werkzeug.formparser import parse_form_data, FormDataParser
-from werkzeug._compat import BytesIO, PY2
+from werkzeug.exceptions import RequestEntityTooLarge
+from werkzeug.formparser import FormDataParser
+from werkzeug.formparser import parse_form_data
+from werkzeug.test import Client
+from werkzeug.test import create_environ
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
@Request.application
def form_data_consumer(request):
- result_object = request.args['object']
- if result_object == 'text':
- return Response(repr(request.form['text']))
+ result_object = request.args["object"]
+ if result_object == "text":
+ return Response(repr(request.form["text"]))
f = request.files[result_object]
- return Response(b'\n'.join((
- repr(f.filename).encode('ascii'),
- repr(f.name).encode('ascii'),
- repr(f.content_type).encode('ascii'),
- f.stream.read()
- )))
+ return Response(
+ b"\n".join(
+ (
+ repr(f.filename).encode("ascii"),
+ repr(f.name).encode("ascii"),
+ repr(f.content_type).encode("ascii"),
+ f.stream.read(),
+ )
+ )
+ )
def get_contents(filename):
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
return f.read()
class TestFormParser(object):
-
def test_limiting(self):
- data = b'foo=Hello+World&bar=baz'
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='application/x-www-form-urlencoded',
- method='POST')
+ data = b"foo=Hello+World&bar=baz"
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="application/x-www-form-urlencoded",
+ method="POST",
+ )
req.max_content_length = 400
- strict_eq(req.form['foo'], u'Hello World')
+ strict_eq(req.form["foo"], u"Hello World")
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='application/x-www-form-urlencoded',
- method='POST')
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="application/x-www-form-urlencoded",
+ method="POST",
+ )
req.max_form_memory_size = 7
- pytest.raises(RequestEntityTooLarge, lambda: req.form['foo'])
+ pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='application/x-www-form-urlencoded',
- method='POST')
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="application/x-www-form-urlencoded",
+ method="POST",
+ )
req.max_form_memory_size = 400
- strict_eq(req.form['foo'], u'Hello World')
-
- data = (b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n'
- b'Hello World\r\n'
- b'--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n'
- b'bar=baz\r\n--foo--')
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
+ strict_eq(req.form["foo"], u"Hello World")
+
+ data = (
+ b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n"
+ b"Hello World\r\n"
+ b"--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n"
+ b"bar=baz\r\n--foo--"
+ )
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
req.max_content_length = 4
- pytest.raises(RequestEntityTooLarge, lambda: req.form['foo'])
+ pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
req.max_content_length = 400
- strict_eq(req.form['foo'], u'Hello World')
+ strict_eq(req.form["foo"], u"Hello World")
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
req.max_form_memory_size = 7
- pytest.raises(RequestEntityTooLarge, lambda: req.form['foo'])
+ pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
req.max_form_memory_size = 400
- strict_eq(req.form['foo'], u'Hello World')
+ strict_eq(req.form["foo"], u"Hello World")
def test_missing_multipart_boundary(self):
- data = (b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n'
- b'Hello World\r\n'
- b'--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n'
- b'bar=baz\r\n--foo--')
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data',
- method='POST')
+ data = (
+ b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n"
+ b"Hello World\r\n"
+ b"--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n"
+ b"bar=baz\r\n--foo--"
+ )
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data",
+ method="POST",
+ )
assert req.form == {}
def test_parse_form_data_put_without_content(self):
@@ -119,24 +146,23 @@ class TestFormParser(object):
# containing an entity-body SHOULD include a Content-Type header field
# defining the media type of that body." In the case where either
# headers are omitted, parse_form_data should still work.
- env = create_environ('/foo', 'http://example.org/', method='PUT')
+ env = create_environ("/foo", "http://example.org/", method="PUT")
stream, form, files = formparser.parse_form_data(env)
- strict_eq(stream.read(), b'')
+ strict_eq(stream.read(), b"")
strict_eq(len(form), 0)
strict_eq(len(files), 0)
def test_parse_form_data_get_without_content(self):
- env = create_environ('/foo', 'http://example.org/', method='GET')
+ env = create_environ("/foo", "http://example.org/", method="GET")
stream, form, files = formparser.parse_form_data(env)
- strict_eq(stream.read(), b'')
+ strict_eq(stream.read(), b"")
strict_eq(len(form), 0)
strict_eq(len(files), 0)
@pytest.mark.parametrize(
- ("no_spooled", "size"),
- ((False, 100), (False, 3000), (True, 100), (True, 3000)),
+ ("no_spooled", "size"), ((False, 100), (False, 3000), (True, 100), (True, 3000))
)
def test_default_stream_factory(self, no_spooled, size, monkeypatch):
if no_spooled:
@@ -144,8 +170,7 @@ class TestFormParser(object):
data = b"a,b,c\n" * size
req = Request.from_values(
- data={"foo": (BytesIO(data), "test.txt")},
- method="POST"
+ data={"foo": (BytesIO(data), "test.txt")}, method="POST"
)
file_storage = req.files["foo"]
@@ -162,268 +187,338 @@ class TestFormParser(object):
file_storage.close()
def test_streaming_parse(self):
- data = b'x' * (1024 * 600)
+ data = b"x" * (1024 * 600)
class StreamMPP(formparser.MultiPartParser):
-
def parse(self, file, boundary, content_length):
- i = iter(self.parse_lines(file, boundary, content_length,
- cap_at_buffer=False))
+ i = iter(
+ self.parse_lines(
+ file, boundary, content_length, cap_at_buffer=False
+ )
+ )
one = next(i)
two = next(i)
- return self.cls(()), {'one': one, 'two': two}
+ return self.cls(()), {"one": one, "two": two}
class StreamFDP(formparser.FormDataParser):
-
- def _sf_parse_multipart(self, stream, mimetype,
- content_length, options):
+ def _sf_parse_multipart(self, stream, mimetype, content_length, options):
form, files = StreamMPP(
- self.stream_factory, self.charset, self.errors,
+ self.stream_factory,
+ self.charset,
+ self.errors,
max_form_memory_size=self.max_form_memory_size,
- cls=self.cls).parse(stream, options.get('boundary').encode('ascii'),
- content_length)
+ cls=self.cls,
+ ).parse(stream, options.get("boundary").encode("ascii"), content_length)
return stream, form, files
+
parse_functions = {}
parse_functions.update(formparser.FormDataParser.parse_functions)
- parse_functions['multipart/form-data'] = _sf_parse_multipart
+ parse_functions["multipart/form-data"] = _sf_parse_multipart
class StreamReq(Request):
form_data_parser_class = StreamFDP
- req = StreamReq.from_values(data={'foo': (BytesIO(data), 'test.txt')},
- method='POST')
- strict_eq('begin_file', req.files['one'][0])
- strict_eq(('foo', 'test.txt'), req.files['one'][1][1:])
- strict_eq('cont', req.files['two'][0])
- strict_eq(data, req.files['two'][1])
+
+ req = StreamReq.from_values(
+ data={"foo": (BytesIO(data), "test.txt")}, method="POST"
+ )
+ strict_eq("begin_file", req.files["one"][0])
+ strict_eq(("foo", "test.txt"), req.files["one"][1][1:])
+ strict_eq("cont", req.files["two"][0])
+ strict_eq(data, req.files["two"][1])
def test_parse_bad_content_type(self):
parser = FormDataParser()
- assert parser.parse('', 'bad-mime-type', 0) == \
- ('', MultiDict([]), MultiDict([]))
+ assert parser.parse("", "bad-mime-type", 0) == (
+ "",
+ MultiDict([]),
+ MultiDict([]),
+ )
def test_parse_from_environ(self):
parser = FormDataParser()
- stream, _, _ = parser.parse_from_environ({'wsgi.input': ''})
+ stream, _, _ = parser.parse_from_environ({"wsgi.input": ""})
assert stream is not None
class TestMultiPart(object):
-
def test_basic(self):
- resources = join(dirname(__file__), 'multipart')
+ resources = join(dirname(__file__), "multipart")
client = Client(form_data_consumer, Response)
repository = [
- ('firefox3-2png1txt', '---------------------------186454651713519341951581030105', [
- (u'anchor.png', 'file1', 'image/png', 'file1.png'),
- (u'application_edit.png', 'file2', 'image/png', 'file2.png')
- ], u'example text'),
- ('firefox3-2pnglongtext', '---------------------------14904044739787191031754711748', [
- (u'accept.png', 'file1', 'image/png', 'file1.png'),
- (u'add.png', 'file2', 'image/png', 'file2.png')
- ], u'--long text\r\n--with boundary\r\n--lookalikes--'),
- ('opera8-2png1txt', '----------zEO9jQKmLc2Cq88c23Dx19', [
- (u'arrow_branch.png', 'file1', 'image/png', 'file1.png'),
- (u'award_star_bronze_1.png', 'file2', 'image/png', 'file2.png')
- ], u'blafasel öäü'),
- ('webkit3-2png1txt', '----WebKitFormBoundaryjdSFhcARk8fyGNy6', [
- (u'gtk-apply.png', 'file1', 'image/png', 'file1.png'),
- (u'gtk-no.png', 'file2', 'image/png', 'file2.png')
- ], u'this is another text with ümläüts'),
- ('ie6-2png1txt', '---------------------------7d91b03a20128', [
- (u'file1.png', 'file1', 'image/x-png', 'file1.png'),
- (u'file2.png', 'file2', 'image/x-png', 'file2.png')
- ], u'ie6 sucks :-/')
+ (
+ "firefox3-2png1txt",
+ "---------------------------186454651713519341951581030105",
+ [
+ (u"anchor.png", "file1", "image/png", "file1.png"),
+ (u"application_edit.png", "file2", "image/png", "file2.png"),
+ ],
+ u"example text",
+ ),
+ (
+ "firefox3-2pnglongtext",
+ "---------------------------14904044739787191031754711748",
+ [
+ (u"accept.png", "file1", "image/png", "file1.png"),
+ (u"add.png", "file2", "image/png", "file2.png"),
+ ],
+ u"--long text\r\n--with boundary\r\n--lookalikes--",
+ ),
+ (
+ "opera8-2png1txt",
+ "----------zEO9jQKmLc2Cq88c23Dx19",
+ [
+ (u"arrow_branch.png", "file1", "image/png", "file1.png"),
+ (u"award_star_bronze_1.png", "file2", "image/png", "file2.png"),
+ ],
+ u"blafasel öäü",
+ ),
+ (
+ "webkit3-2png1txt",
+ "----WebKitFormBoundaryjdSFhcARk8fyGNy6",
+ [
+ (u"gtk-apply.png", "file1", "image/png", "file1.png"),
+ (u"gtk-no.png", "file2", "image/png", "file2.png"),
+ ],
+ u"this is another text with ümläüts",
+ ),
+ (
+ "ie6-2png1txt",
+ "---------------------------7d91b03a20128",
+ [
+ (u"file1.png", "file1", "image/x-png", "file1.png"),
+ (u"file2.png", "file2", "image/x-png", "file2.png"),
+ ],
+ u"ie6 sucks :-/",
+ ),
]
for name, boundary, files, text in repository:
folder = join(resources, name)
- data = get_contents(join(folder, 'request.txt'))
+ data = get_contents(join(folder, "request.http"))
for filename, field, content_type, fsname in files:
response = client.post(
- '/?object=' + field,
+ "/?object=" + field,
data=data,
content_type='multipart/form-data; boundary="%s"' % boundary,
- content_length=len(data))
- lines = response.get_data().split(b'\n', 3)
- strict_eq(lines[0], repr(filename).encode('ascii'))
- strict_eq(lines[1], repr(field).encode('ascii'))
- strict_eq(lines[2], repr(content_type).encode('ascii'))
+ content_length=len(data),
+ )
+ lines = response.get_data().split(b"\n", 3)
+ strict_eq(lines[0], repr(filename).encode("ascii"))
+ strict_eq(lines[1], repr(field).encode("ascii"))
+ strict_eq(lines[2], repr(content_type).encode("ascii"))
strict_eq(lines[3], get_contents(join(folder, fsname)))
response = client.post(
- '/?object=text',
+ "/?object=text",
data=data,
content_type='multipart/form-data; boundary="%s"' % boundary,
- content_length=len(data))
- strict_eq(response.get_data(), repr(text).encode('utf-8'))
+ content_length=len(data),
+ )
+ strict_eq(response.get_data(), repr(text).encode("utf-8"))
def test_ie7_unc_path(self):
client = Client(form_data_consumer, Response)
- data_file = join(dirname(__file__), 'multipart', 'ie7_full_path_request.txt')
+ data_file = join(dirname(__file__), "multipart", "ie7_full_path_request.http")
data = get_contents(data_file)
- boundary = '---------------------------7da36d1b4a0164'
+ boundary = "---------------------------7da36d1b4a0164"
response = client.post(
- '/?object=cb_file_upload_multiple',
+ "/?object=cb_file_upload_multiple",
data=data,
content_type='multipart/form-data; boundary="%s"' % boundary,
- content_length=len(data))
- lines = response.get_data().split(b'\n', 3)
- strict_eq(lines[0],
- repr(u'Sellersburg Town Council Meeting 02-22-2010doc.doc').encode('ascii'))
+ content_length=len(data),
+ )
+ lines = response.get_data().split(b"\n", 3)
+ strict_eq(
+ lines[0],
+ repr(u"Sellersburg Town Council Meeting 02-22-2010doc.doc").encode("ascii"),
+ )
def test_end_of_file(self):
# This test looks innocent but it was actually timeing out in
# the Werkzeug 0.5 release version (#394)
data = (
- b'--foo\r\n'
+ b"--foo\r\n"
b'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n'
- b'Content-Type: text/plain\r\n\r\n'
- b'file contents and no end'
+ b"Content-Type: text/plain\r\n\r\n"
+ b"file contents and no end"
+ )
+ data = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
)
- data = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
assert not data.files
assert not data.form
def test_broken(self):
data = (
- '--foo\r\n'
+ "--foo\r\n"
'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n'
- 'Content-Transfer-Encoding: base64\r\n'
- 'Content-Type: text/plain\r\n\r\n'
- 'broken base 64'
- '--foo--'
- )
- _, form, files = formparser.parse_form_data(create_environ(
- data=data, method='POST', content_type='multipart/form-data; boundary=foo'
- ))
+ "Content-Transfer-Encoding: base64\r\n"
+ "Content-Type: text/plain\r\n\r\n"
+ "broken base 64"
+ "--foo--"
+ )
+ _, form, files = formparser.parse_form_data(
+ create_environ(
+ data=data,
+ method="POST",
+ content_type="multipart/form-data; boundary=foo",
+ )
+ )
assert not files
assert not form
- pytest.raises(ValueError, formparser.parse_form_data,
- create_environ(data=data, method='POST',
- content_type='multipart/form-data; boundary=foo'),
- silent=False)
+ pytest.raises(
+ ValueError,
+ formparser.parse_form_data,
+ create_environ(
+ data=data,
+ method="POST",
+ content_type="multipart/form-data; boundary=foo",
+ ),
+ silent=False,
+ )
def test_file_no_content_type(self):
data = (
- b'--foo\r\n'
+ b"--foo\r\n"
b'Content-Disposition: form-data; name="test"; filename="test.txt"\r\n\r\n'
- b'file contents\r\n--foo--'
+ b"file contents\r\n--foo--"
)
- data = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
- assert data.files['test'].filename == 'test.txt'
- strict_eq(data.files['test'].read(), b'file contents')
+ data = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
+ assert data.files["test"].filename == "test.txt"
+ strict_eq(data.files["test"].read(), b"file contents")
def test_extra_newline(self):
        # this test looks innocent but it was actually timing out in
# the Werkzeug 0.5 release version (#394)
data = (
- b'\r\n\r\n--foo\r\n'
+ b"\r\n\r\n--foo\r\n"
b'Content-Disposition: form-data; name="foo"\r\n\r\n'
- b'a string\r\n'
- b'--foo--'
+ b"a string\r\n"
+ b"--foo--"
+ )
+ data = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
)
- data = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
assert not data.files
- strict_eq(data.form['foo'], u'a string')
+ strict_eq(data.form["foo"], u"a string")
def test_headers(self):
- data = (b'--foo\r\n'
- b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
- b'X-Custom-Header: blah\r\n'
- b'Content-Type: text/plain; charset=utf-8\r\n\r\n'
- b'file contents, just the contents\r\n'
- b'--foo--')
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
- foo = req.files['foo']
- strict_eq(foo.mimetype, 'text/plain')
- strict_eq(foo.mimetype_params, {'charset': 'utf-8'})
- strict_eq(foo.headers['content-type'], foo.content_type)
- strict_eq(foo.content_type, 'text/plain; charset=utf-8')
- strict_eq(foo.headers['x-custom-header'], 'blah')
+ data = (
+ b"--foo\r\n"
+ b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
+ b"X-Custom-Header: blah\r\n"
+ b"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+ b"file contents, just the contents\r\n"
+ b"--foo--"
+ )
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
+ foo = req.files["foo"]
+ strict_eq(foo.mimetype, "text/plain")
+ strict_eq(foo.mimetype_params, {"charset": "utf-8"})
+ strict_eq(foo.headers["content-type"], foo.content_type)
+ strict_eq(foo.content_type, "text/plain; charset=utf-8")
+ strict_eq(foo.headers["x-custom-header"], "blah")
def test_nonstandard_line_endings(self):
- for nl in b'\n', b'\r', b'\r\n':
- data = nl.join((
- b'--foo',
- b'Content-Disposition: form-data; name=foo',
- b'',
- b'this is just bar',
- b'--foo',
- b'Content-Disposition: form-data; name=bar',
- b'',
- b'blafasel',
- b'--foo--'
- ))
- req = Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; '
- 'boundary=foo', method='POST')
- strict_eq(req.form['foo'], u'this is just bar')
- strict_eq(req.form['bar'], u'blafasel')
+ for nl in b"\n", b"\r", b"\r\n":
+ data = nl.join(
+ (
+ b"--foo",
+ b"Content-Disposition: form-data; name=foo",
+ b"",
+ b"this is just bar",
+ b"--foo",
+ b"Content-Disposition: form-data; name=bar",
+ b"",
+ b"blafasel",
+ b"--foo--",
+ )
+ )
+ req = Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
+ strict_eq(req.form["foo"], u"this is just bar")
+ strict_eq(req.form["bar"], u"blafasel")
def test_failures(self):
def parse_multipart(stream, boundary, content_length):
parser = formparser.MultiPartParser(content_length)
return parser.parse(stream, boundary, content_length)
- pytest.raises(ValueError, parse_multipart, BytesIO(), b'broken ', 0)
- data = b'--foo\r\n\r\nHello World\r\n--foo--'
- pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data))
+ pytest.raises(ValueError, parse_multipart, BytesIO(), b"broken ", 0)
- data = b'--foo\r\nContent-Disposition: form-field; name=foo\r\n' \
- b'Content-Transfer-Encoding: base64\r\n\r\nHello World\r\n--foo--'
- pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data))
+ data = b"--foo\r\n\r\nHello World\r\n--foo--"
+ pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data))
- data = b'--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\nHello World\r\n'
- pytest.raises(ValueError, parse_multipart, BytesIO(data), b'foo', len(data))
+ data = (
+ b"--foo\r\nContent-Disposition: form-field; name=foo\r\n"
+ b"Content-Transfer-Encoding: base64\r\n\r\nHello World\r\n--foo--"
+ )
+ pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data))
- x = formparser.parse_multipart_headers(['foo: bar\r\n', ' x test\r\n'])
- strict_eq(x['foo'], 'bar\n x test')
- pytest.raises(ValueError, formparser.parse_multipart_headers,
- ['foo: bar\r\n', ' x test'])
+ data = (
+ b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\nHello World\r\n"
+ )
+ pytest.raises(ValueError, parse_multipart, BytesIO(data), b"foo", len(data))
+
+ x = formparser.parse_multipart_headers(["foo: bar\r\n", " x test\r\n"])
+ strict_eq(x["foo"], "bar\n x test")
+ pytest.raises(
+ ValueError, formparser.parse_multipart_headers, ["foo: bar\r\n", " x test"]
+ )
def test_bad_newline_bad_newline_assumption(self):
class ISORequest(Request):
- charset = 'latin1'
- contents = b'U2vlbmUgbORu'
- data = b'--foo\r\nContent-Disposition: form-data; name="test"\r\n' \
- b'Content-Transfer-Encoding: base64\r\n\r\n' + \
- contents + b'\r\n--foo--'
- req = ISORequest.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST')
- strict_eq(req.form['test'], u'Sk\xe5ne l\xe4n')
+ charset = "latin1"
+
+ contents = b"U2vlbmUgbORu"
+ data = (
+ b'--foo\r\nContent-Disposition: form-data; name="test"\r\n'
+ b"Content-Transfer-Encoding: base64\r\n\r\n" + contents + b"\r\n--foo--"
+ )
+ req = ISORequest.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
+ )
+ strict_eq(req.form["test"], u"Sk\xe5ne l\xe4n")
def test_empty_multipart(self):
environ = {}
- data = b'--boundary--'
- environ['REQUEST_METHOD'] = 'POST'
- environ['CONTENT_TYPE'] = 'multipart/form-data; boundary=boundary'
- environ['CONTENT_LENGTH'] = str(len(data))
- environ['wsgi.input'] = BytesIO(data)
+ data = b"--boundary--"
+ environ["REQUEST_METHOD"] = "POST"
+ environ["CONTENT_TYPE"] = "multipart/form-data; boundary=boundary"
+ environ["CONTENT_LENGTH"] = str(len(data))
+ environ["wsgi.input"] = BytesIO(data)
stream, form, files = parse_form_data(environ, silent=False)
rv = stream.read()
- assert rv == b''
+ assert rv == b""
assert form == MultiDict()
assert files == MultiDict()
class TestMultiPartParser(object):
-
def test_constructor_not_pass_stream_factory_and_cls(self):
parser = formparser.MultiPartParser()
@@ -443,7 +538,7 @@ class TestMultiPartParser(object):
data = (
b"--foo\r\n"
b"Content-Type: text/plain; charset=utf-8\r\n"
- b'Content-Disposition: form-data; name=rfc2231;\r\n'
+ b"Content-Disposition: form-data; name=rfc2231;\r\n"
b" filename*0*=ascii''a%20b%20;\r\n"
b" filename*1*=c%20d%20;\r\n"
b' filename*2="e f.txt"\r\n\r\n'
@@ -453,25 +548,24 @@ class TestMultiPartParser(object):
input_stream=BytesIO(data),
content_length=len(data),
content_type="multipart/form-data; boundary=foo",
- method="POST"
+ method="POST",
)
assert request.files["rfc2231"].filename == "a b c d e f.txt"
assert request.files["rfc2231"].read() == b"file contents"
class TestInternalFunctions(object):
-
def test_line_parser(self):
- assert formparser._line_parse('foo') == ('foo', False)
- assert formparser._line_parse('foo\r\n') == ('foo', True)
- assert formparser._line_parse('foo\r') == ('foo', True)
- assert formparser._line_parse('foo\n') == ('foo', True)
+ assert formparser._line_parse("foo") == ("foo", False)
+ assert formparser._line_parse("foo\r\n") == ("foo", True)
+ assert formparser._line_parse("foo\r") == ("foo", True)
+ assert formparser._line_parse("foo\n") == ("foo", True)
def test_find_terminator(self):
- lineiter = iter(b'\n\n\nfoo\nbar\nbaz'.splitlines(True))
+ lineiter = iter(b"\n\n\nfoo\nbar\nbaz".splitlines(True))
find_terminator = formparser.MultiPartParser()._find_terminator
line = find_terminator(lineiter)
- assert line == b'foo'
- assert list(lineiter) == [b'bar\n', b'baz']
- assert find_terminator([]) == b''
- assert find_terminator([b'']) == b''
+ assert line == b"foo"
+ assert list(lineiter) == [b"bar\n", b"baz"]
+ assert find_terminator([]) == b""
+ assert find_terminator([b""]) == b""
diff --git a/tests/test_http.py b/tests/test_http.py
index b261a605..c6e8309a 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -8,199 +8,222 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pytest
-
from datetime import datetime
-from tests import strict_eq
-from werkzeug._compat import itervalues, wsgi_encoding_dance
+import pytest
-from werkzeug import http, datastructures
+from . import strict_eq
+from werkzeug import datastructures
+from werkzeug import http
+from werkzeug._compat import itervalues
+from werkzeug._compat import wsgi_encoding_dance
from werkzeug.test import create_environ
class TestHTTPUtility(object):
-
def test_accept(self):
- a = http.parse_accept_header('en-us,ru;q=0.5')
- assert list(itervalues(a)) == ['en-us', 'ru']
- assert a.best == 'en-us'
- assert a.find('ru') == 1
- pytest.raises(ValueError, a.index, 'de')
- assert a.to_header() == 'en-us,ru;q=0.5'
+ a = http.parse_accept_header("en-us,ru;q=0.5")
+ assert list(itervalues(a)) == ["en-us", "ru"]
+ assert a.best == "en-us"
+ assert a.find("ru") == 1
+ pytest.raises(ValueError, a.index, "de")
+ assert a.to_header() == "en-us,ru;q=0.5"
def test_mime_accept(self):
- a = http.parse_accept_header('text/xml,application/xml,'
- 'application/xhtml+xml,'
- 'application/foo;quiet=no; bar=baz;q=0.6,'
- 'text/html;q=0.9,text/plain;q=0.8,'
- 'image/png,*/*;q=0.5',
- datastructures.MIMEAccept)
- pytest.raises(ValueError, lambda: a['missing'])
- assert a['image/png'] == 1
- assert a['text/plain'] == 0.8
- assert a['foo/bar'] == 0.5
- assert a['application/foo;quiet=no; bar=baz'] == 0.6
- assert a[a.find('foo/bar')] == ('*/*', 0.5)
+ a = http.parse_accept_header(
+ "text/xml,application/xml,"
+ "application/xhtml+xml,"
+ "application/foo;quiet=no; bar=baz;q=0.6,"
+ "text/html;q=0.9,text/plain;q=0.8,"
+ "image/png,*/*;q=0.5",
+ datastructures.MIMEAccept,
+ )
+ pytest.raises(ValueError, lambda: a["missing"])
+ assert a["image/png"] == 1
+ assert a["text/plain"] == 0.8
+ assert a["foo/bar"] == 0.5
+ assert a["application/foo;quiet=no; bar=baz"] == 0.6
+ assert a[a.find("foo/bar")] == ("*/*", 0.5)
def test_accept_matches(self):
- a = http.parse_accept_header('text/xml,application/xml,application/xhtml+xml,'
- 'text/html;q=0.9,text/plain;q=0.8,'
- 'image/png', datastructures.MIMEAccept)
- assert a.best_match(['text/html', 'application/xhtml+xml']) == \
- 'application/xhtml+xml'
- assert a.best_match(['text/html']) == 'text/html'
- assert a.best_match(['foo/bar']) is None
- assert a.best_match(['foo/bar', 'bar/foo'], default='foo/bar') == 'foo/bar'
- assert a.best_match(['application/xml', 'text/xml']) == 'application/xml'
+ a = http.parse_accept_header(
+ "text/xml,application/xml,application/xhtml+xml,"
+ "text/html;q=0.9,text/plain;q=0.8,"
+ "image/png",
+ datastructures.MIMEAccept,
+ )
+ assert (
+ a.best_match(["text/html", "application/xhtml+xml"])
+ == "application/xhtml+xml"
+ )
+ assert a.best_match(["text/html"]) == "text/html"
+ assert a.best_match(["foo/bar"]) is None
+ assert a.best_match(["foo/bar", "bar/foo"], default="foo/bar") == "foo/bar"
+ assert a.best_match(["application/xml", "text/xml"]) == "application/xml"
def test_charset_accept(self):
- a = http.parse_accept_header('ISO-8859-1,utf-8;q=0.7,*;q=0.7',
- datastructures.CharsetAccept)
- assert a['iso-8859-1'] == a['iso8859-1']
- assert a['iso-8859-1'] == 1
- assert a['UTF8'] == 0.7
- assert a['ebcdic'] == 0.7
+ a = http.parse_accept_header(
+ "ISO-8859-1,utf-8;q=0.7,*;q=0.7", datastructures.CharsetAccept
+ )
+ assert a["iso-8859-1"] == a["iso8859-1"]
+ assert a["iso-8859-1"] == 1
+ assert a["UTF8"] == 0.7
+ assert a["ebcdic"] == 0.7
def test_language_accept(self):
- a = http.parse_accept_header('de-AT,de;q=0.8,en;q=0.5',
- datastructures.LanguageAccept)
- assert a.best == 'de-AT'
- assert 'de_AT' in a
- assert 'en' in a
- assert a['de-at'] == 1
- assert a['en'] == 0.5
+ a = http.parse_accept_header(
+ "de-AT,de;q=0.8,en;q=0.5", datastructures.LanguageAccept
+ )
+ assert a.best == "de-AT"
+ assert "de_AT" in a
+ assert "en" in a
+ assert a["de-at"] == 1
+ assert a["en"] == 0.5
def test_set_header(self):
hs = http.parse_set_header('foo, Bar, "Blah baz", Hehe')
- assert 'blah baz' in hs
- assert 'foobar' not in hs
- assert 'foo' in hs
- assert list(hs) == ['foo', 'Bar', 'Blah baz', 'Hehe']
- hs.add('Foo')
+ assert "blah baz" in hs
+ assert "foobar" not in hs
+ assert "foo" in hs
+ assert list(hs) == ["foo", "Bar", "Blah baz", "Hehe"]
+ hs.add("Foo")
assert hs.to_header() == 'foo, Bar, "Blah baz", Hehe'
def test_list_header(self):
- hl = http.parse_list_header('foo baz, blah')
- assert hl == ['foo baz', 'blah']
+ hl = http.parse_list_header("foo baz, blah")
+ assert hl == ["foo baz", "blah"]
def test_dict_header(self):
d = http.parse_dict_header('foo="bar baz", blah=42')
- assert d == {'foo': 'bar baz', 'blah': '42'}
+ assert d == {"foo": "bar baz", "blah": "42"}
def test_cache_control_header(self):
- cc = http.parse_cache_control_header('max-age=0, no-cache')
+ cc = http.parse_cache_control_header("max-age=0, no-cache")
assert cc.max_age == 0
assert cc.no_cache
- cc = http.parse_cache_control_header('private, community="UCI"', None,
- datastructures.ResponseCacheControl)
+ cc = http.parse_cache_control_header(
+ 'private, community="UCI"', None, datastructures.ResponseCacheControl
+ )
assert cc.private
- assert cc['community'] == 'UCI'
+ assert cc["community"] == "UCI"
c = datastructures.ResponseCacheControl()
assert c.no_cache is None
assert c.private is None
c.no_cache = True
- assert c.no_cache == '*'
+ assert c.no_cache == "*"
c.private = True
- assert c.private == '*'
+ assert c.private == "*"
del c.private
assert c.private is None
- assert c.to_header() == 'no-cache'
+ assert c.to_header() == "no-cache"
def test_authorization_header(self):
- a = http.parse_authorization_header('Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==')
- assert a.type == 'basic'
- assert a.username == u'Aladdin'
- assert a.password == u'open sesame'
-
- a = http.parse_authorization_header('Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==')
- assert a.type == 'basic'
- assert a.username == u'русскиЁ'
- assert a.password == u'Буквы'
-
- a = http.parse_authorization_header('Basic 5pmu6YCa6K+dOuS4reaWhw==')
- assert a.type == 'basic'
- assert a.username == u'普通话'
- assert a.password == u'中文'
-
- a = http.parse_authorization_header('''Digest username="Mufasa",
- realm="testrealm@host.invalid",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- uri="/dir/index.html",
- qop=auth,
- nc=00000001,
- cnonce="0a4f113b",
- response="6629fae49393a05397450978507c4ef1",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"''')
- assert a.type == 'digest'
- assert a.username == 'Mufasa'
- assert a.realm == 'testrealm@host.invalid'
- assert a.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093'
- assert a.uri == '/dir/index.html'
- assert a.qop == 'auth'
- assert a.nc == '00000001'
- assert a.cnonce == '0a4f113b'
- assert a.response == '6629fae49393a05397450978507c4ef1'
- assert a.opaque == '5ccc069c403ebaf9f0171e9517f40e41'
-
- a = http.parse_authorization_header('''Digest username="Mufasa",
- realm="testrealm@host.invalid",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- uri="/dir/index.html",
- response="e257afa1414a3340d93d30955171dd0e",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"''')
- assert a.type == 'digest'
- assert a.username == 'Mufasa'
- assert a.realm == 'testrealm@host.invalid'
- assert a.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093'
- assert a.uri == '/dir/index.html'
- assert a.response == 'e257afa1414a3340d93d30955171dd0e'
- assert a.opaque == '5ccc069c403ebaf9f0171e9517f40e41'
-
- assert http.parse_authorization_header('') is None
+ a = http.parse_authorization_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ assert a.type == "basic"
+ assert a.username == u"Aladdin"
+ assert a.password == u"open sesame"
+
+ a = http.parse_authorization_header(
+ "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw=="
+ )
+ assert a.type == "basic"
+ assert a.username == u"русскиЁ"
+ assert a.password == u"Буквы"
+
+ a = http.parse_authorization_header("Basic 5pmu6YCa6K+dOuS4reaWhw==")
+ assert a.type == "basic"
+ assert a.username == u"普通话"
+ assert a.password == u"中文"
+
+ a = http.parse_authorization_header(
+ '''Digest username="Mufasa",
+ realm="testrealm@host.invalid",
+ nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
+ uri="/dir/index.html",
+ qop=auth,
+ nc=00000001,
+ cnonce="0a4f113b",
+ response="6629fae49393a05397450978507c4ef1",
+ opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ )
+ assert a.type == "digest"
+ assert a.username == "Mufasa"
+ assert a.realm == "testrealm@host.invalid"
+ assert a.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093"
+ assert a.uri == "/dir/index.html"
+ assert a.qop == "auth"
+ assert a.nc == "00000001"
+ assert a.cnonce == "0a4f113b"
+ assert a.response == "6629fae49393a05397450978507c4ef1"
+ assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
+
+ a = http.parse_authorization_header(
+ '''Digest username="Mufasa",
+ realm="testrealm@host.invalid",
+ nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
+ uri="/dir/index.html",
+ response="e257afa1414a3340d93d30955171dd0e",
+ opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ )
+ assert a.type == "digest"
+ assert a.username == "Mufasa"
+ assert a.realm == "testrealm@host.invalid"
+ assert a.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093"
+ assert a.uri == "/dir/index.html"
+ assert a.response == "e257afa1414a3340d93d30955171dd0e"
+ assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
+
+ assert http.parse_authorization_header("") is None
assert http.parse_authorization_header(None) is None
- assert http.parse_authorization_header('foo') is None
+ assert http.parse_authorization_header("foo") is None
def test_www_authenticate_header(self):
wa = http.parse_www_authenticate_header('Basic realm="WallyWorld"')
- assert wa.type == 'basic'
- assert wa.realm == 'WallyWorld'
- wa.realm = 'Foo Bar'
+ assert wa.type == "basic"
+ assert wa.realm == "WallyWorld"
+ wa.realm = "Foo Bar"
assert wa.to_header() == 'Basic realm="Foo Bar"'
- wa = http.parse_www_authenticate_header('''Digest
- realm="testrealm@host.com",
- qop="auth,auth-int",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"''')
- assert wa.type == 'digest'
- assert wa.realm == 'testrealm@host.com'
- assert 'auth' in wa.qop
- assert 'auth-int' in wa.qop
- assert wa.nonce == 'dcd98b7102dd2f0e8b11d0f600bfb0c093'
- assert wa.opaque == '5ccc069c403ebaf9f0171e9517f40e41'
+ wa = http.parse_www_authenticate_header(
+ '''Digest
+ realm="testrealm@host.com",
+ qop="auth,auth-int",
+ nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
+ opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ )
+ assert wa.type == "digest"
+ assert wa.realm == "testrealm@host.com"
+ assert "auth" in wa.qop
+ assert "auth-int" in wa.qop
+ assert wa.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093"
+ assert wa.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
- wa = http.parse_www_authenticate_header('broken')
- assert wa.type == 'broken'
+ wa = http.parse_www_authenticate_header("broken")
+ assert wa.type == "broken"
- assert not http.parse_www_authenticate_header('').type
- assert not http.parse_www_authenticate_header('')
+ assert not http.parse_www_authenticate_header("").type
+ assert not http.parse_www_authenticate_header("")
def test_etags(self):
- assert http.quote_etag('foo') == '"foo"'
- assert http.quote_etag('foo', True) == 'W/"foo"'
- assert http.unquote_etag('"foo"') == ('foo', False)
- assert http.unquote_etag('W/"foo"') == ('foo', True)
+ assert http.quote_etag("foo") == '"foo"'
+ assert http.quote_etag("foo", True) == 'W/"foo"'
+ assert http.unquote_etag('"foo"') == ("foo", False)
+ assert http.unquote_etag('W/"foo"') == ("foo", True)
es = http.parse_etags('"foo", "bar", W/"baz", blar')
- assert sorted(es) == ['bar', 'blar', 'foo']
- assert 'foo' in es
- assert 'baz' not in es
- assert es.contains_weak('baz')
- assert 'blar' in es
+ assert sorted(es) == ["bar", "blar", "foo"]
+ assert "foo" in es
+ assert "baz" not in es
+ assert es.contains_weak("baz")
+ assert "blar" in es
assert es.contains_raw('W/"baz"')
assert es.contains_raw('"foo"')
- assert sorted(es.to_header().split(', ')) == ['"bar"', '"blar"', '"foo"', 'W/"baz"']
+ assert sorted(es.to_header().split(", ")) == [
+ '"bar"',
+ '"blar"',
+ '"foo"',
+ 'W/"baz"',
+ ]
def test_etags_nonzero(self):
etags = http.parse_etags('W/"foo"')
@@ -208,392 +231,433 @@ class TestHTTPUtility(object):
assert etags.contains_raw('W/"foo"')
def test_parse_date(self):
- assert http.parse_date('Sun, 06 Nov 1994 08:49:37 GMT ') == datetime(
- 1994, 11, 6, 8, 49, 37)
- assert http.parse_date('Sunday, 06-Nov-94 08:49:37 GMT') == datetime(1994, 11, 6, 8, 49, 37)
- assert http.parse_date(' Sun Nov 6 08:49:37 1994') == datetime(1994, 11, 6, 8, 49, 37)
- assert http.parse_date('foo') is None
+ assert http.parse_date("Sun, 06 Nov 1994 08:49:37 GMT ") == datetime(
+ 1994, 11, 6, 8, 49, 37
+ )
+ assert http.parse_date("Sunday, 06-Nov-94 08:49:37 GMT") == datetime(
+ 1994, 11, 6, 8, 49, 37
+ )
+ assert http.parse_date(" Sun Nov 6 08:49:37 1994") == datetime(
+ 1994, 11, 6, 8, 49, 37
+ )
+ assert http.parse_date("foo") is None
def test_parse_date_overflows(self):
- assert http.parse_date(' Sun 02 Feb 1343 08:49:37 GMT') == datetime(1343, 2, 2, 8, 49, 37)
- assert http.parse_date('Thu, 01 Jan 1970 00:00:00 GMT') == datetime(1970, 1, 1, 0, 0)
- assert http.parse_date('Thu, 33 Jan 1970 00:00:00 GMT') is None
+ assert http.parse_date(" Sun 02 Feb 1343 08:49:37 GMT") == datetime(
+ 1343, 2, 2, 8, 49, 37
+ )
+ assert http.parse_date("Thu, 01 Jan 1970 00:00:00 GMT") == datetime(
+ 1970, 1, 1, 0, 0
+ )
+ assert http.parse_date("Thu, 33 Jan 1970 00:00:00 GMT") is None
def test_remove_entity_headers(self):
now = http.http_date()
- headers1 = [('Date', now), ('Content-Type', 'text/html'), ('Content-Length', '0')]
+ headers1 = [
+ ("Date", now),
+ ("Content-Type", "text/html"),
+ ("Content-Length", "0"),
+ ]
headers2 = datastructures.Headers(headers1)
http.remove_entity_headers(headers1)
- assert headers1 == [('Date', now)]
+ assert headers1 == [("Date", now)]
http.remove_entity_headers(headers2)
- assert headers2 == datastructures.Headers([(u'Date', now)])
+ assert headers2 == datastructures.Headers([(u"Date", now)])
def test_remove_hop_by_hop_headers(self):
- headers1 = [('Connection', 'closed'), ('Foo', 'bar'),
- ('Keep-Alive', 'wtf')]
+ headers1 = [("Connection", "closed"), ("Foo", "bar"), ("Keep-Alive", "wtf")]
headers2 = datastructures.Headers(headers1)
http.remove_hop_by_hop_headers(headers1)
- assert headers1 == [('Foo', 'bar')]
+ assert headers1 == [("Foo", "bar")]
http.remove_hop_by_hop_headers(headers2)
- assert headers2 == datastructures.Headers([('Foo', 'bar')])
+ assert headers2 == datastructures.Headers([("Foo", "bar")])
def test_parse_options_header(self):
- assert http.parse_options_header(None) == \
- ('', {})
- assert http.parse_options_header("") == \
- ('', {})
- assert http.parse_options_header(r'something; foo="other\"thing"') == \
- ('something', {'foo': 'other"thing'})
- assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == \
- ('something', {'foo': 'other"thing', 'meh': '42'})
- assert http.parse_options_header(r'something; foo="other\"thing"; meh=42; bleh') == \
- ('something', {'foo': 'other"thing', 'meh': '42', 'bleh': None})
- assert http.parse_options_header('something; foo="other;thing"; meh=42; bleh') == \
- ('something', {'foo': 'other;thing', 'meh': '42', 'bleh': None})
- assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == \
- ('something', {'foo': 'otherthing', 'meh': None, 'bleh': None})
+ assert http.parse_options_header(None) == ("", {})
+ assert http.parse_options_header("") == ("", {})
+ assert http.parse_options_header(r'something; foo="other\"thing"') == (
+ "something",
+ {"foo": 'other"thing'},
+ )
+ assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == (
+ "something",
+ {"foo": 'other"thing', "meh": "42"},
+ )
+ assert http.parse_options_header(
+ r'something; foo="other\"thing"; meh=42; bleh'
+ ) == ("something", {"foo": 'other"thing', "meh": "42", "bleh": None})
+ assert http.parse_options_header(
+ 'something; foo="other;thing"; meh=42; bleh'
+ ) == ("something", {"foo": "other;thing", "meh": "42", "bleh": None})
+ assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == (
+ "something",
+ {"foo": "otherthing", "meh": None, "bleh": None},
+ )
# Issue #404
- assert http.parse_options_header('multipart/form-data; name="foo bar"; '
- 'filename="bar foo"') == \
- ('multipart/form-data', {'name': 'foo bar', 'filename': 'bar foo'})
+ assert http.parse_options_header(
+ 'multipart/form-data; name="foo bar"; ' 'filename="bar foo"'
+ ) == ("multipart/form-data", {"name": "foo bar", "filename": "bar foo"})
# Examples from RFC
- assert http.parse_options_header('audio/*; q=0.2, audio/basic') == \
- ('audio/*', {'q': '0.2'})
- assert http.parse_options_header('audio/*; q=0.2, audio/basic', multiple=True) == \
- ('audio/*', {'q': '0.2'}, "audio/basic", {})
+ assert http.parse_options_header("audio/*; q=0.2, audio/basic") == (
+ "audio/*",
+ {"q": "0.2"},
+ )
+ assert http.parse_options_header(
+ "audio/*; q=0.2, audio/basic", multiple=True
+ ) == ("audio/*", {"q": "0.2"}, "audio/basic", {})
+ assert http.parse_options_header(
+ "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c",
+ multiple=True,
+ ) == (
+ "text/plain",
+ {"q": "0.5"},
+ "text/html",
+ {},
+ "text/x-dvi",
+ {"q": "0.8"},
+ "text/x-c",
+ {},
+ )
assert http.parse_options_header(
- 'text/plain; q=0.5, text/html\n '
- 'text/x-dvi; q=0.8, text/x-c',
- multiple=True) == \
- ('text/plain', {'q': '0.5'}, "text/html", {},
- "text/x-dvi", {'q': '0.8'}, "text/x-c", {})
- assert http.parse_options_header('text/plain; q=0.5, text/html\n'
- ' '
- 'text/x-dvi; q=0.8, text/x-c') == \
- ('text/plain', {'q': '0.5'})
+ "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c"
+ ) == ("text/plain", {"q": "0.5"})
# Issue #932
assert http.parse_options_header(
- 'form-data; '
- 'name="a_file"; '
- 'filename*=UTF-8\'\''
- '"%c2%a3%20and%20%e2%82%ac%20rates"') == \
- ('form-data', {'name': 'a_file',
- 'filename': u'\xa3 and \u20ac rates'})
+ "form-data; name=\"a_file\"; filename*=UTF-8''"
+ '"%c2%a3%20and%20%e2%82%ac%20rates"'
+ ) == ("form-data", {"name": "a_file", "filename": u"\xa3 and \u20ac rates"})
assert http.parse_options_header(
- 'form-data; '
- 'name*=UTF-8\'\'"%C5%AAn%C4%ADc%C5%8Dde%CC%BD"; '
- 'filename="some_file.txt"') == \
- ('form-data', {'name': u'\u016an\u012dc\u014dde\u033d',
- 'filename': 'some_file.txt'})
+ "form-data; name*=UTF-8''\"%C5%AAn%C4%ADc%C5%8Dde%CC%BD\"; "
+ 'filename="some_file.txt"'
+ ) == (
+ "form-data",
+ {"name": u"\u016an\u012dc\u014dde\u033d", "filename": "some_file.txt"},
+ )
def test_parse_options_header_value_with_quotes(self):
assert http.parse_options_header(
'form-data; name="file"; filename="t\'es\'t.txt"'
- ) == ('form-data', {'name': 'file', 'filename': "t'es't.txt"})
+ ) == ("form-data", {"name": "file", "filename": "t'es't.txt"})
assert http.parse_options_header(
- 'form-data; name="file"; filename*=UTF-8\'\'"\'🐍\'.txt"'
- ) == ('form-data', {'name': 'file', 'filename': u"'🐍'.txt"})
+ "form-data; name=\"file\"; filename*=UTF-8''\"'🐍'.txt\""
+ ) == ("form-data", {"name": "file", "filename": u"'🐍'.txt"})
def test_parse_options_header_broken_values(self):
# Issue #995
- assert http.parse_options_header(' ') == ('', {})
- assert http.parse_options_header(' , ') == ('', {})
- assert http.parse_options_header(' ; ') == ('', {})
- assert http.parse_options_header(' ,; ') == ('', {})
- assert http.parse_options_header(' , a ') == ('', {})
- assert http.parse_options_header(' ; a ') == ('', {})
+ assert http.parse_options_header(" ") == ("", {})
+ assert http.parse_options_header(" , ") == ("", {})
+ assert http.parse_options_header(" ; ") == ("", {})
+ assert http.parse_options_header(" ,; ") == ("", {})
+ assert http.parse_options_header(" , a ") == ("", {})
+ assert http.parse_options_header(" ; a ") == ("", {})
def test_dump_options_header(self):
- assert http.dump_options_header('foo', {'bar': 42}) == \
- 'foo; bar=42'
- assert http.dump_options_header('foo', {'bar': 42, 'fizz': None}) in \
- ('foo; bar=42; fizz', 'foo; fizz; bar=42')
+ assert http.dump_options_header("foo", {"bar": 42}) == "foo; bar=42"
+ assert http.dump_options_header("foo", {"bar": 42, "fizz": None}) in (
+ "foo; bar=42; fizz",
+ "foo; fizz; bar=42",
+ )
def test_dump_header(self):
- assert http.dump_header([1, 2, 3]) == '1, 2, 3'
+ assert http.dump_header([1, 2, 3]) == "1, 2, 3"
assert http.dump_header([1, 2, 3], allow_token=False) == '"1", "2", "3"'
- assert http.dump_header({'foo': 'bar'}, allow_token=False) == 'foo="bar"'
- assert http.dump_header({'foo': 'bar'}) == 'foo=bar'
+ assert http.dump_header({"foo": "bar"}, allow_token=False) == 'foo="bar"'
+ assert http.dump_header({"foo": "bar"}) == "foo=bar"
def test_is_resource_modified(self):
env = create_environ()
# ignore POST
- env['REQUEST_METHOD'] = 'POST'
- assert not http.is_resource_modified(env, etag='testing')
- env['REQUEST_METHOD'] = 'GET'
+ env["REQUEST_METHOD"] = "POST"
+ assert not http.is_resource_modified(env, etag="testing")
+ env["REQUEST_METHOD"] = "GET"
# etagify from data
- pytest.raises(TypeError, http.is_resource_modified, env,
- data='42', etag='23')
- env['HTTP_IF_NONE_MATCH'] = http.generate_etag(b'awesome')
- assert not http.is_resource_modified(env, data=b'awesome')
+ pytest.raises(TypeError, http.is_resource_modified, env, data="42", etag="23")
+ env["HTTP_IF_NONE_MATCH"] = http.generate_etag(b"awesome")
+ assert not http.is_resource_modified(env, data=b"awesome")
- env['HTTP_IF_MODIFIED_SINCE'] = http.http_date(datetime(2008, 1, 1, 12, 30))
- assert not http.is_resource_modified(env,
- last_modified=datetime(2008, 1, 1, 12, 00))
- assert http.is_resource_modified(env,
- last_modified=datetime(2008, 1, 1, 13, 00))
+ env["HTTP_IF_MODIFIED_SINCE"] = http.http_date(datetime(2008, 1, 1, 12, 30))
+ assert not http.is_resource_modified(
+ env, last_modified=datetime(2008, 1, 1, 12, 00)
+ )
+ assert http.is_resource_modified(
+ env, last_modified=datetime(2008, 1, 1, 13, 00)
+ )
def test_is_resource_modified_for_range_requests(self):
env = create_environ()
- env['HTTP_IF_MODIFIED_SINCE'] = http.http_date(datetime(2008, 1, 1, 12, 30))
- env['HTTP_IF_RANGE'] = http.generate_etag(b'awesome_if_range')
+ env["HTTP_IF_MODIFIED_SINCE"] = http.http_date(datetime(2008, 1, 1, 12, 30))
+ env["HTTP_IF_RANGE"] = http.generate_etag(b"awesome_if_range")
# Range header not present, so If-Range should be ignored
- assert not http.is_resource_modified(env, data=b'not_the_same',
- ignore_if_range=False,
- last_modified=datetime(2008, 1, 1, 12, 30))
-
- env['HTTP_RANGE'] = ''
- assert not http.is_resource_modified(env, data=b'awesome_if_range',
- ignore_if_range=False)
- assert http.is_resource_modified(env, data=b'not_the_same',
- ignore_if_range=False)
-
- env['HTTP_IF_RANGE'] = http.http_date(datetime(2008, 1, 1, 13, 30))
- assert http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 14, 00),
- ignore_if_range=False)
- assert not http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 13, 30),
- ignore_if_range=False)
- assert http.is_resource_modified(env, last_modified=datetime(2008, 1, 1, 13, 30),
- ignore_if_range=True)
+ assert not http.is_resource_modified(
+ env,
+ data=b"not_the_same",
+ ignore_if_range=False,
+ last_modified=datetime(2008, 1, 1, 12, 30),
+ )
+
+ env["HTTP_RANGE"] = ""
+ assert not http.is_resource_modified(
+ env, data=b"awesome_if_range", ignore_if_range=False
+ )
+ assert http.is_resource_modified(
+ env, data=b"not_the_same", ignore_if_range=False
+ )
+
+ env["HTTP_IF_RANGE"] = http.http_date(datetime(2008, 1, 1, 13, 30))
+ assert http.is_resource_modified(
+ env, last_modified=datetime(2008, 1, 1, 14, 00), ignore_if_range=False
+ )
+ assert not http.is_resource_modified(
+ env, last_modified=datetime(2008, 1, 1, 13, 30), ignore_if_range=False
+ )
+ assert http.is_resource_modified(
+ env, last_modified=datetime(2008, 1, 1, 13, 30), ignore_if_range=True
+ )
def test_date_formatting(self):
- assert http.cookie_date(0) == 'Thu, 01-Jan-1970 00:00:00 GMT'
- assert http.cookie_date(datetime(1970, 1, 1)) == 'Thu, 01-Jan-1970 00:00:00 GMT'
- assert http.http_date(0) == 'Thu, 01 Jan 1970 00:00:00 GMT'
- assert http.http_date(datetime(1970, 1, 1)) == 'Thu, 01 Jan 1970 00:00:00 GMT'
+ assert http.cookie_date(0) == "Thu, 01-Jan-1970 00:00:00 GMT"
+ assert http.cookie_date(datetime(1970, 1, 1)) == "Thu, 01-Jan-1970 00:00:00 GMT"
+ assert http.http_date(0) == "Thu, 01 Jan 1970 00:00:00 GMT"
+ assert http.http_date(datetime(1970, 1, 1)) == "Thu, 01 Jan 1970 00:00:00 GMT"
def test_cookies(self):
strict_eq(
- dict(http.parse_cookie('dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cd'
- 'c762809248d4beed; a=42; b="\\\";"')),
+ dict(
+ http.parse_cookie(
+ "dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cd"
+ 'c762809248d4beed; a=42; b="\\";"'
+ )
+ ),
{
- 'CP': u'null*',
- 'PHPSESSID': u'0a539d42abc001cdc762809248d4beed',
- 'a': u'42',
- 'dismiss-top': u'6',
- 'b': u'\";'
- }
- )
- rv = http.dump_cookie('foo', 'bar baz blub', 360, httponly=True,
- sync_expires=False)
+ "CP": u"null*",
+ "PHPSESSID": u"0a539d42abc001cdc762809248d4beed",
+ "a": u"42",
+ "dismiss-top": u"6",
+ "b": u'";',
+ },
+ )
+ rv = http.dump_cookie(
+ "foo", "bar baz blub", 360, httponly=True, sync_expires=False
+ )
assert type(rv) is str
- assert set(rv.split('; ')) == set(['HttpOnly', 'Max-Age=360',
- 'Path=/', 'foo="bar baz blub"'])
+ assert set(rv.split("; ")) == {
+ "HttpOnly",
+ "Max-Age=360",
+ "Path=/",
+ 'foo="bar baz blub"',
+ }
- strict_eq(dict(http.parse_cookie('fo234{=bar; blub=Blah')),
- {'fo234{': u'bar', 'blub': u'Blah'})
+ strict_eq(
+ dict(http.parse_cookie("fo234{=bar; blub=Blah")),
+ {"fo234{": u"bar", "blub": u"Blah"},
+ )
- strict_eq(http.dump_cookie('key', 'xxx/'), 'key=xxx/; Path=/')
- strict_eq(http.dump_cookie('key', 'xxx='), 'key=xxx=; Path=/')
+ strict_eq(http.dump_cookie("key", "xxx/"), "key=xxx/; Path=/")
+ strict_eq(http.dump_cookie("key", "xxx="), "key=xxx=; Path=/")
def test_bad_cookies(self):
strict_eq(
- dict(http.parse_cookie(
- 'first=IamTheFirst ; a=1; oops ; a=2 ;second = andMeTwo;'
- )),
- {
- 'first': u'IamTheFirst',
- 'a': u'2',
- 'oops': u'',
- 'second': u'andMeTwo',
- }
+ dict(
+ http.parse_cookie(
+ "first=IamTheFirst ; a=1; oops ; a=2 ;second = andMeTwo;"
+ )
+ ),
+ {"first": u"IamTheFirst", "a": u"2", "oops": u"", "second": u"andMeTwo"},
)
def test_empty_keys_are_ignored(self):
strict_eq(
- dict(http.parse_cookie(
- 'first=IamTheFirst ; a=1; a=2 ;second=andMeTwo; ; '
- )),
- {
- 'first': u'IamTheFirst',
- 'a': u'2',
- 'second': u'andMeTwo'
- }
+ dict(
+ http.parse_cookie("first=IamTheFirst ; a=1; a=2 ;second=andMeTwo; ; ")
+ ),
+ {"first": u"IamTheFirst", "a": u"2", "second": u"andMeTwo"},
)
def test_cookie_quoting(self):
val = http.dump_cookie("foo", "?foo")
strict_eq(val, 'foo="?foo"; Path=/')
- strict_eq(dict(http.parse_cookie(val)), {'foo': u'?foo'})
+ strict_eq(dict(http.parse_cookie(val)), {"foo": u"?foo"})
- strict_eq(dict(http.parse_cookie(r'foo="foo\054bar"')),
- {'foo': u'foo,bar'})
+ strict_eq(dict(http.parse_cookie(r'foo="foo\054bar"')), {"foo": u"foo,bar"})
def test_cookie_domain_resolving(self):
- val = http.dump_cookie('foo', 'bar', domain=u'\N{SNOWMAN}.com')
- strict_eq(val, 'foo=bar; Domain=xn--n3h.com; Path=/')
+ val = http.dump_cookie("foo", "bar", domain=u"\N{SNOWMAN}.com")
+ strict_eq(val, "foo=bar; Domain=xn--n3h.com; Path=/")
def test_cookie_unicode_dumping(self):
- val = http.dump_cookie('foo', u'\N{SNOWMAN}')
+ val = http.dump_cookie("foo", u"\N{SNOWMAN}")
h = datastructures.Headers()
- h.add('Set-Cookie', val)
- assert h['Set-Cookie'] == 'foo="\\342\\230\\203"; Path=/'
+ h.add("Set-Cookie", val)
+ assert h["Set-Cookie"] == 'foo="\\342\\230\\203"; Path=/'
- cookies = http.parse_cookie(h['Set-Cookie'])
- assert cookies['foo'] == u'\N{SNOWMAN}'
+ cookies = http.parse_cookie(h["Set-Cookie"])
+ assert cookies["foo"] == u"\N{SNOWMAN}"
def test_cookie_unicode_keys(self):
# Yes, this is technically against the spec but happens
- val = http.dump_cookie(u'fö', u'fö')
- assert val == wsgi_encoding_dance(u'fö="f\\303\\266"; Path=/', 'utf-8')
+ val = http.dump_cookie(u"fö", u"fö")
+ assert val == wsgi_encoding_dance(u'fö="f\\303\\266"; Path=/', "utf-8")
cookies = http.parse_cookie(val)
- assert cookies[u'fö'] == u'fö'
+ assert cookies[u"fö"] == u"fö"
def test_cookie_unicode_parsing(self):
# This is actually a correct test. This is what is being submitted
# by firefox if you set an unicode cookie and we get the cookie sent
# in on Python 3 under PEP 3333.
- cookies = http.parse_cookie(u'fö=fö')
- assert cookies[u'fö'] == u'fö'
+ cookies = http.parse_cookie(u"fö=fö")
+ assert cookies[u"fö"] == u"fö"
def test_cookie_domain_encoding(self):
- val = http.dump_cookie('foo', 'bar', domain=u'\N{SNOWMAN}.com')
- strict_eq(val, 'foo=bar; Domain=xn--n3h.com; Path=/')
+ val = http.dump_cookie("foo", "bar", domain=u"\N{SNOWMAN}.com")
+ strict_eq(val, "foo=bar; Domain=xn--n3h.com; Path=/")
- val = http.dump_cookie('foo', 'bar', domain=u'.\N{SNOWMAN}.com')
- strict_eq(val, 'foo=bar; Domain=.xn--n3h.com; Path=/')
+ val = http.dump_cookie("foo", "bar", domain=u".\N{SNOWMAN}.com")
+ strict_eq(val, "foo=bar; Domain=.xn--n3h.com; Path=/")
- val = http.dump_cookie('foo', 'bar', domain=u'.foo.com')
- strict_eq(val, 'foo=bar; Domain=.foo.com; Path=/')
+ val = http.dump_cookie("foo", "bar", domain=u".foo.com")
+ strict_eq(val, "foo=bar; Domain=.foo.com; Path=/")
def test_cookie_maxsize(self, recwarn):
- val = http.dump_cookie('foo', 'bar' * 1360 + 'b')
+ val = http.dump_cookie("foo", "bar" * 1360 + "b")
assert len(recwarn) == 0
assert len(val) == 4093
- http.dump_cookie('foo', 'bar' * 1360 + 'ba')
+ http.dump_cookie("foo", "bar" * 1360 + "ba")
assert len(recwarn) == 1
w = recwarn.pop()
- assert 'cookie is too large' in str(w.message)
+ assert "cookie is too large" in str(w.message)
- http.dump_cookie('foo', b'w' * 502, max_size=512)
+ http.dump_cookie("foo", b"w" * 502, max_size=512)
assert len(recwarn) == 1
w = recwarn.pop()
- assert 'the limit is 512 bytes' in str(w.message)
-
- @pytest.mark.parametrize('input, expected', [
- ('strict', 'foo=bar; Path=/; SameSite=Strict'),
- ('lax', 'foo=bar; Path=/; SameSite=Lax'),
- (None, 'foo=bar; Path=/'),
- ])
+ assert "the limit is 512 bytes" in str(w.message)
+
+ @pytest.mark.parametrize(
+ "input, expected",
+ [
+ ("strict", "foo=bar; Path=/; SameSite=Strict"),
+ ("lax", "foo=bar; Path=/; SameSite=Lax"),
+ (None, "foo=bar; Path=/"),
+ ],
+ )
def test_cookie_samesite_attribute(self, input, expected):
- val = http.dump_cookie('foo', 'bar', samesite=input)
+ val = http.dump_cookie("foo", "bar", samesite=input)
strict_eq(val, expected)
class TestRange(object):
-
def test_if_range_parsing(self):
rv = http.parse_if_range_header('"Test"')
- assert rv.etag == 'Test'
+ assert rv.etag == "Test"
assert rv.date is None
assert rv.to_header() == '"Test"'
# weak information is dropped
rv = http.parse_if_range_header('W/"Test"')
- assert rv.etag == 'Test'
+ assert rv.etag == "Test"
assert rv.date is None
assert rv.to_header() == '"Test"'
# broken etags are supported too
- rv = http.parse_if_range_header('bullshit')
- assert rv.etag == 'bullshit'
+ rv = http.parse_if_range_header("bullshit")
+ assert rv.etag == "bullshit"
assert rv.date is None
assert rv.to_header() == '"bullshit"'
- rv = http.parse_if_range_header('Thu, 01 Jan 1970 00:00:00 GMT')
+ rv = http.parse_if_range_header("Thu, 01 Jan 1970 00:00:00 GMT")
assert rv.etag is None
assert rv.date == datetime(1970, 1, 1)
- assert rv.to_header() == 'Thu, 01 Jan 1970 00:00:00 GMT'
+ assert rv.to_header() == "Thu, 01 Jan 1970 00:00:00 GMT"
- for x in '', None:
+ for x in "", None:
rv = http.parse_if_range_header(x)
assert rv.etag is None
assert rv.date is None
- assert rv.to_header() == ''
+ assert rv.to_header() == ""
def test_range_parsing(self):
- rv = http.parse_range_header('bytes=52')
+ rv = http.parse_range_header("bytes=52")
assert rv is None
- rv = http.parse_range_header('bytes=52-')
- assert rv.units == 'bytes'
+ rv = http.parse_range_header("bytes=52-")
+ assert rv.units == "bytes"
assert rv.ranges == [(52, None)]
- assert rv.to_header() == 'bytes=52-'
+ assert rv.to_header() == "bytes=52-"
- rv = http.parse_range_header('bytes=52-99')
- assert rv.units == 'bytes'
+ rv = http.parse_range_header("bytes=52-99")
+ assert rv.units == "bytes"
assert rv.ranges == [(52, 100)]
- assert rv.to_header() == 'bytes=52-99'
+ assert rv.to_header() == "bytes=52-99"
- rv = http.parse_range_header('bytes=52-99,-1000')
- assert rv.units == 'bytes'
+ rv = http.parse_range_header("bytes=52-99,-1000")
+ assert rv.units == "bytes"
assert rv.ranges == [(52, 100), (-1000, None)]
- assert rv.to_header() == 'bytes=52-99,-1000'
+ assert rv.to_header() == "bytes=52-99,-1000"
- rv = http.parse_range_header('bytes = 1 - 100')
- assert rv.units == 'bytes'
+ rv = http.parse_range_header("bytes = 1 - 100")
+ assert rv.units == "bytes"
assert rv.ranges == [(1, 101)]
- assert rv.to_header() == 'bytes=1-100'
+ assert rv.to_header() == "bytes=1-100"
- rv = http.parse_range_header('AWesomes=0-999')
- assert rv.units == 'awesomes'
+ rv = http.parse_range_header("AWesomes=0-999")
+ assert rv.units == "awesomes"
assert rv.ranges == [(0, 1000)]
- assert rv.to_header() == 'awesomes=0-999'
+ assert rv.to_header() == "awesomes=0-999"
- rv = http.parse_range_header('bytes=-')
+ rv = http.parse_range_header("bytes=-")
assert rv is None
- rv = http.parse_range_header('bytes=bad')
+ rv = http.parse_range_header("bytes=bad")
assert rv is None
- rv = http.parse_range_header('bytes=bad-1')
+ rv = http.parse_range_header("bytes=bad-1")
assert rv is None
- rv = http.parse_range_header('bytes=-bad')
+ rv = http.parse_range_header("bytes=-bad")
assert rv is None
- rv = http.parse_range_header('bytes=52-99, bad')
+ rv = http.parse_range_header("bytes=52-99, bad")
assert rv is None
def test_content_range_parsing(self):
- rv = http.parse_content_range_header('bytes 0-98/*')
- assert rv.units == 'bytes'
+ rv = http.parse_content_range_header("bytes 0-98/*")
+ assert rv.units == "bytes"
assert rv.start == 0
assert rv.stop == 99
assert rv.length is None
- assert rv.to_header() == 'bytes 0-98/*'
+ assert rv.to_header() == "bytes 0-98/*"
- rv = http.parse_content_range_header('bytes 0-98/*asdfsa')
+ rv = http.parse_content_range_header("bytes 0-98/*asdfsa")
assert rv is None
- rv = http.parse_content_range_header('bytes 0-99/100')
- assert rv.to_header() == 'bytes 0-99/100'
+ rv = http.parse_content_range_header("bytes 0-99/100")
+ assert rv.to_header() == "bytes 0-99/100"
rv.start = None
rv.stop = None
- assert rv.units == 'bytes'
- assert rv.to_header() == 'bytes */100'
+ assert rv.units == "bytes"
+ assert rv.to_header() == "bytes */100"
- rv = http.parse_content_range_header('bytes */100')
+ rv = http.parse_content_range_header("bytes */100")
assert rv.start is None
assert rv.stop is None
assert rv.length == 100
- assert rv.units == 'bytes'
+ assert rv.units == "bytes"
class TestRegression(object):
-
def test_best_match_works(self):
# was a bug in 0.6
- rv = http.parse_accept_header('foo=,application/xml,application/xhtml+xml,'
- 'text/html;q=0.9,text/plain;q=0.8,'
- 'image/png,*/*;q=0.5',
- datastructures.MIMEAccept).best_match(['foo/bar'])
- assert rv == 'foo/bar'
+ rv = http.parse_accept_header(
+ "foo=,application/xml,application/xhtml+xml,"
+ "text/html;q=0.9,text/plain;q=0.8,"
+ "image/png,*/*;q=0.5",
+ datastructures.MIMEAccept,
+ ).best_match(["foo/bar"])
+ assert rv == "foo/bar"
diff --git a/tests/test_internal.py b/tests/test_internal.py
index e272b58e..ca2f92cb 100644
--- a/tests/test_internal.py
+++ b/tests/test_internal.py
@@ -8,15 +8,16 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pytest
-
from datetime import datetime
-from warnings import filterwarnings, resetwarnings
+from warnings import filterwarnings
+from warnings import resetwarnings
-from werkzeug.wrappers import Request, Response
+import pytest
from werkzeug import _internal as internal
from werkzeug.test import create_environ
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def test_date_to_unix():
@@ -28,44 +29,44 @@ def test_date_to_unix():
def test_easteregg():
- req = Request.from_values('/?macgybarchakku')
+ req = Request.from_values("/?macgybarchakku")
resp = Response.force_type(internal._easteregg(None), req)
- assert b'About Werkzeug' in resp.get_data()
- assert b'the Swiss Army knife of Python web development' in resp.get_data()
+ assert b"About Werkzeug" in resp.get_data()
+ assert b"the Swiss Army knife of Python web development" in resp.get_data()
def test_wrapper_internals():
- req = Request.from_values(data={'foo': 'bar'}, method='POST')
+ req = Request.from_values(data={"foo": "bar"}, method="POST")
req._load_form_data()
- assert req.form.to_dict() == {'foo': 'bar'}
+ assert req.form.to_dict() == {"foo": "bar"}
# second call does not break
req._load_form_data()
- assert req.form.to_dict() == {'foo': 'bar'}
+ assert req.form.to_dict() == {"foo": "bar"}
# check reprs
assert repr(req) == "<Request 'http://localhost/' [POST]>"
resp = Response()
- assert repr(resp) == '<Response 0 bytes [200 OK]>'
- resp.set_data('Hello World!')
- assert repr(resp) == '<Response 12 bytes [200 OK]>'
- resp.response = iter(['Test'])
- assert repr(resp) == '<Response streamed [200 OK]>'
+ assert repr(resp) == "<Response 0 bytes [200 OK]>"
+ resp.set_data("Hello World!")
+ assert repr(resp) == "<Response 12 bytes [200 OK]>"
+ resp.response = iter(["Test"])
+ assert repr(resp) == "<Response streamed [200 OK]>"
# unicode data does not set content length
- response = Response([u'Hällo Wörld'])
+ response = Response([u"Hällo Wörld"])
headers = response.get_wsgi_headers(create_environ())
- assert u'Content-Length' not in headers
+ assert u"Content-Length" not in headers
- response = Response([u'Hällo Wörld'.encode('utf-8')])
+ response = Response([u"Hällo Wörld".encode("utf-8")])
headers = response.get_wsgi_headers(create_environ())
- assert u'Content-Length' in headers
+ assert u"Content-Length" in headers
# check for internal warnings
- filterwarnings('error', category=Warning)
+ filterwarnings("error", category=Warning)
response = Response()
environ = create_environ()
- response.response = 'What the...?'
+ response.response = "What the...?"
pytest.raises(Warning, lambda: list(response.iter_encoded()))
pytest.raises(Warning, lambda: list(response.get_app_iter(environ)))
response.direct_passthrough = True
diff --git a/tests/test_local.py b/tests/test_local.py
index 26f63eba..9bbd0fe8 100644
--- a/tests/test_local.py
+++ b/tests/test_local.py
@@ -8,13 +8,13 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pytest
-
-import time
import copy
+import time
from functools import partial
from threading import Thread
+import pytest
+
from werkzeug import local
@@ -28,8 +28,8 @@ def test_basic_local():
ns.foo = idx
time.sleep(0.02)
values.append(ns.foo)
- threads = [Thread(target=value_setter, args=(x,))
- for x in [1, 2, 3]]
+
+ threads = [Thread(target=value_setter, args=(x,)) for x in [1, 2, 3]]
for thread in threads:
thread.start()
time.sleep(0.2)
@@ -37,6 +37,7 @@ def test_basic_local():
def delfoo():
del ns.foo
+
delfoo()
pytest.raises(AttributeError, lambda: ns.foo)
pytest.raises(AttributeError, delfoo)
@@ -48,7 +49,7 @@ def test_local_release():
ns = local.Local()
ns.foo = 42
local.release_local(ns)
- assert not hasattr(ns, 'foo')
+ assert not hasattr(ns, "foo")
ls = local.LocalStack()
ls.push(42)
@@ -122,7 +123,7 @@ def test_local_stack():
assert proxy == (1, 2)
ls.pop()
ls.pop()
- assert repr(proxy) == '<LocalProxy unbound>'
+ assert repr(proxy) == "<LocalProxy unbound>"
assert ident not in ls._local.__storage__
@@ -144,18 +145,18 @@ def test_custom_idents():
local.LocalManager([ns, stack], ident_func=lambda: ident)
ns.foo = 42
- stack.push({'foo': 42})
+ stack.push({"foo": 42})
ident = 1
ns.foo = 23
- stack.push({'foo': 23})
+ stack.push({"foo": 23})
ident = 0
assert ns.foo == 42
- assert stack.top['foo'] == 42
+ assert stack.top["foo"] == 42
stack.pop()
assert stack.top is None
ident = 1
assert ns.foo == 23
- assert stack.top['foo'] == 23
+ assert stack.top["foo"] == 23
stack.pop()
assert stack.top is None
@@ -169,6 +170,7 @@ def test_deepcopy_on_proxy():
def __deepcopy__(self, memo):
return self
+
f = Foo()
p = local.LocalProxy(lambda: f)
assert p.attr == 42
@@ -186,7 +188,7 @@ def test_deepcopy_on_proxy():
def test_local_proxy_wrapped_attribute():
class SomeClassWithWrapped(object):
- __wrapped__ = 'wrapped'
+ __wrapped__ = "wrapped"
def lookup_func():
return 42
@@ -203,5 +205,5 @@ def test_local_proxy_wrapped_attribute():
ns.foo = SomeClassWithWrapped()
ns.bar = 42
- assert ns('foo').__wrapped__ == 'wrapped'
- pytest.raises(AttributeError, lambda: ns('bar').__wrapped__)
+ assert ns("foo").__wrapped__ == "wrapped"
+ pytest.raises(AttributeError, lambda: ns("bar").__wrapped__)
diff --git a/tests/test_routing.py b/tests/test_routing.py
index fa604ee0..835ca684 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -8,286 +8,313 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pytest
-
import uuid
-from tests import strict_eq
+import pytest
+from . import strict_eq
from werkzeug import routing as r
-from werkzeug.wrappers import Response
-from werkzeug.datastructures import ImmutableDict, MultiDict
+from werkzeug.datastructures import ImmutableDict
+from werkzeug.datastructures import MultiDict
from werkzeug.test import create_environ
+from werkzeug.wrappers import Response
def test_basic_routing():
- map = r.Map([
- r.Rule('/', endpoint='index'),
- r.Rule('/foo', endpoint='foo'),
- r.Rule('/bar/', endpoint='bar')
- ])
- adapter = map.bind('example.org', '/')
- assert adapter.match('/') == ('index', {})
- assert adapter.match('/foo') == ('foo', {})
- assert adapter.match('/bar/') == ('bar', {})
- pytest.raises(r.RequestRedirect, lambda: adapter.match('/bar'))
- pytest.raises(r.NotFound, lambda: adapter.match('/blub'))
-
- adapter = map.bind('example.org', '/test')
+ map = r.Map(
+ [
+ r.Rule("/", endpoint="index"),
+ r.Rule("/foo", endpoint="foo"),
+ r.Rule("/bar/", endpoint="bar"),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
+ assert adapter.match("/") == ("index", {})
+ assert adapter.match("/foo") == ("foo", {})
+ assert adapter.match("/bar/") == ("bar", {})
+ pytest.raises(r.RequestRedirect, lambda: adapter.match("/bar"))
+ pytest.raises(r.NotFound, lambda: adapter.match("/blub"))
+
+ adapter = map.bind("example.org", "/test")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match('/bar')
- assert excinfo.value.new_url == 'http://example.org/test/bar/'
+ adapter.match("/bar")
+ assert excinfo.value.new_url == "http://example.org/test/bar/"
- adapter = map.bind('example.org', '/')
+ adapter = map.bind("example.org", "/")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match('/bar')
- assert excinfo.value.new_url == 'http://example.org/bar/'
+ adapter.match("/bar")
+ assert excinfo.value.new_url == "http://example.org/bar/"
- adapter = map.bind('example.org', '/')
+ adapter = map.bind("example.org", "/")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match('/bar', query_args={'aha': 'muhaha'})
- assert excinfo.value.new_url == 'http://example.org/bar/?aha=muhaha'
+ adapter.match("/bar", query_args={"aha": "muhaha"})
+ assert excinfo.value.new_url == "http://example.org/bar/?aha=muhaha"
- adapter = map.bind('example.org', '/')
+ adapter = map.bind("example.org", "/")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match('/bar', query_args='aha=muhaha')
- assert excinfo.value.new_url == 'http://example.org/bar/?aha=muhaha'
+ adapter.match("/bar", query_args="aha=muhaha")
+ assert excinfo.value.new_url == "http://example.org/bar/?aha=muhaha"
- adapter = map.bind_to_environ(create_environ('/bar?foo=bar',
- 'http://example.org/'))
+ adapter = map.bind_to_environ(create_environ("/bar?foo=bar", "http://example.org/"))
with pytest.raises(r.RequestRedirect) as excinfo:
adapter.match()
- assert excinfo.value.new_url == 'http://example.org/bar/?foo=bar'
+ assert excinfo.value.new_url == "http://example.org/bar/?foo=bar"
def test_strict_slashes_redirect():
- map = r.Map([
- r.Rule('/bar/', endpoint='get', methods=["GET"]),
- r.Rule('/bar', endpoint='post', methods=["POST"]),
- r.Rule('/foo/', endpoint='foo', methods=["POST"]),
- ])
- adapter = map.bind('example.org', '/')
+ map = r.Map(
+ [
+ r.Rule("/bar/", endpoint="get", methods=["GET"]),
+ r.Rule("/bar", endpoint="post", methods=["POST"]),
+ r.Rule("/foo/", endpoint="foo", methods=["POST"]),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
# Check if the actual routes works
- assert adapter.match('/bar/', method='GET') == ('get', {})
- assert adapter.match('/bar', method='POST') == ('post', {})
+ assert adapter.match("/bar/", method="GET") == ("get", {})
+ assert adapter.match("/bar", method="POST") == ("post", {})
# Check if exceptions are correct
- pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET')
- pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST')
+ pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET")
+ pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST")
with pytest.raises(r.RequestRedirect) as error_info:
- adapter.match('/foo', method='POST')
+ adapter.match("/foo", method="POST")
assert error_info.value.code == 308
# Check differently defined order
- map = r.Map([
- r.Rule('/bar', endpoint='post', methods=["POST"]),
- r.Rule('/bar/', endpoint='get', methods=["GET"]),
- ])
- adapter = map.bind('example.org', '/')
+ map = r.Map(
+ [
+ r.Rule("/bar", endpoint="post", methods=["POST"]),
+ r.Rule("/bar/", endpoint="get", methods=["GET"]),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
# Check if the actual routes works
- assert adapter.match('/bar/', method='GET') == ('get', {})
- assert adapter.match('/bar', method='POST') == ('post', {})
+ assert adapter.match("/bar/", method="GET") == ("get", {})
+ assert adapter.match("/bar", method="POST") == ("post", {})
# Check if exceptions are correct
- pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET')
- pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST')
+ pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET")
+ pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST")
# Check what happens when only slash route is defined
- map = r.Map([
- r.Rule('/bar/', endpoint='get', methods=["GET"]),
- ])
- adapter = map.bind('example.org', '/')
+ map = r.Map([r.Rule("/bar/", endpoint="get", methods=["GET"])])
+ adapter = map.bind("example.org", "/")
# Check if the actual routes works
- assert adapter.match('/bar/', method='GET') == ('get', {})
+ assert adapter.match("/bar/", method="GET") == ("get", {})
# Check if exceptions are correct
- pytest.raises(r.RequestRedirect, adapter.match, '/bar', method='GET')
- pytest.raises(r.MethodNotAllowed, adapter.match, '/bar/', method='POST')
- pytest.raises(r.MethodNotAllowed, adapter.match, '/bar', method='POST')
+ pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET")
+ pytest.raises(r.MethodNotAllowed, adapter.match, "/bar/", method="POST")
+ pytest.raises(r.MethodNotAllowed, adapter.match, "/bar", method="POST")
def test_environ_defaults():
environ = create_environ("/foo")
- strict_eq(environ["PATH_INFO"], '/foo')
+ strict_eq(environ["PATH_INFO"], "/foo")
m = r.Map([r.Rule("/foo", endpoint="foo"), r.Rule("/bar", endpoint="bar")])
a = m.bind_to_environ(environ)
- strict_eq(a.match("/foo"), ('foo', {}))
- strict_eq(a.match(), ('foo', {}))
- strict_eq(a.match("/bar"), ('bar', {}))
+ strict_eq(a.match("/foo"), ("foo", {}))
+ strict_eq(a.match(), ("foo", {}))
+ strict_eq(a.match("/bar"), ("bar", {}))
pytest.raises(r.NotFound, a.match, "/bars")
def test_environ_nonascii_pathinfo():
- environ = create_environ(u'/лошадь')
- m = r.Map([
- r.Rule(u'/', endpoint='index'),
- r.Rule(u'/лошадь', endpoint='horse')
- ])
+ environ = create_environ(u"/лошадь")
+ m = r.Map([r.Rule(u"/", endpoint="index"), r.Rule(u"/лошадь", endpoint="horse")])
a = m.bind_to_environ(environ)
- strict_eq(a.match(u'/'), ('index', {}))
- strict_eq(a.match(u'/лошадь'), ('horse', {}))
- pytest.raises(r.NotFound, a.match, u'/барсук')
+ strict_eq(a.match(u"/"), ("index", {}))
+ strict_eq(a.match(u"/лошадь"), ("horse", {}))
+ pytest.raises(r.NotFound, a.match, u"/барсук")
def test_basic_building():
- map = r.Map([
- r.Rule('/', endpoint='index'),
- r.Rule('/foo', endpoint='foo'),
- r.Rule('/bar/<baz>', endpoint='bar'),
- r.Rule('/bar/<int:bazi>', endpoint='bari'),
- r.Rule('/bar/<float:bazf>', endpoint='barf'),
- r.Rule('/bar/<path:bazp>', endpoint='barp'),
- r.Rule('/hehe', endpoint='blah', subdomain='blah')
- ])
- adapter = map.bind('example.org', '/', subdomain='blah')
-
- assert adapter.build('index', {}) == 'http://example.org/'
- assert adapter.build('foo', {}) == 'http://example.org/foo'
- assert adapter.build('bar', {'baz': 'blub'}) == \
- 'http://example.org/bar/blub'
- assert adapter.build('bari', {'bazi': 50}) == 'http://example.org/bar/50'
- assert adapter.build('barf', {'bazf': 0.815}) == \
- 'http://example.org/bar/0.815'
- assert adapter.build('barp', {'bazp': 'la/di'}) == \
- 'http://example.org/bar/la/di'
- assert adapter.build('blah', {}) == '/hehe'
- pytest.raises(r.BuildError, lambda: adapter.build('urks'))
-
- adapter = map.bind('example.org', '/test', subdomain='blah')
- assert adapter.build('index', {}) == 'http://example.org/test/'
- assert adapter.build('foo', {}) == 'http://example.org/test/foo'
- assert adapter.build('bar', {'baz': 'blub'}) == \
- 'http://example.org/test/bar/blub'
- assert adapter.build('bari', {'bazi': 50}) == 'http://example.org/test/bar/50'
- assert adapter.build('barf', {'bazf': 0.815}) == 'http://example.org/test/bar/0.815'
- assert adapter.build('barp', {'bazp': 'la/di'}) == 'http://example.org/test/bar/la/di'
- assert adapter.build('blah', {}) == '/test/hehe'
-
- adapter = map.bind('example.org')
- assert adapter.build('foo', {}) == '/foo'
- assert adapter.build('foo', {}, force_external=True) == 'http://example.org/foo'
- adapter = map.bind('example.org', url_scheme='')
- assert adapter.build('foo', {}) == '/foo'
- assert adapter.build('foo', {}, force_external=True) == '//example.org/foo'
+ map = r.Map(
+ [
+ r.Rule("/", endpoint="index"),
+ r.Rule("/foo", endpoint="foo"),
+ r.Rule("/bar/<baz>", endpoint="bar"),
+ r.Rule("/bar/<int:bazi>", endpoint="bari"),
+ r.Rule("/bar/<float:bazf>", endpoint="barf"),
+ r.Rule("/bar/<path:bazp>", endpoint="barp"),
+ r.Rule("/hehe", endpoint="blah", subdomain="blah"),
+ ]
+ )
+ adapter = map.bind("example.org", "/", subdomain="blah")
+
+ assert adapter.build("index", {}) == "http://example.org/"
+ assert adapter.build("foo", {}) == "http://example.org/foo"
+ assert adapter.build("bar", {"baz": "blub"}) == "http://example.org/bar/blub"
+ assert adapter.build("bari", {"bazi": 50}) == "http://example.org/bar/50"
+ assert adapter.build("barf", {"bazf": 0.815}) == "http://example.org/bar/0.815"
+ assert adapter.build("barp", {"bazp": "la/di"}) == "http://example.org/bar/la/di"
+ assert adapter.build("blah", {}) == "/hehe"
+ pytest.raises(r.BuildError, lambda: adapter.build("urks"))
+
+ adapter = map.bind("example.org", "/test", subdomain="blah")
+ assert adapter.build("index", {}) == "http://example.org/test/"
+ assert adapter.build("foo", {}) == "http://example.org/test/foo"
+ assert adapter.build("bar", {"baz": "blub"}) == "http://example.org/test/bar/blub"
+ assert adapter.build("bari", {"bazi": 50}) == "http://example.org/test/bar/50"
+ assert adapter.build("barf", {"bazf": 0.815}) == "http://example.org/test/bar/0.815"
+ assert (
+ adapter.build("barp", {"bazp": "la/di"}) == "http://example.org/test/bar/la/di"
+ )
+ assert adapter.build("blah", {}) == "/test/hehe"
+
+ adapter = map.bind("example.org")
+ assert adapter.build("foo", {}) == "/foo"
+ assert adapter.build("foo", {}, force_external=True) == "http://example.org/foo"
+ adapter = map.bind("example.org", url_scheme="")
+ assert adapter.build("foo", {}) == "/foo"
+ assert adapter.build("foo", {}, force_external=True) == "//example.org/foo"
def test_long_build():
- long_args = dict(('v%d' % x, x) for x in range(10000))
- map = r.Map([
- r.Rule(''.join('/<%s>' % k for k in long_args.keys()), endpoint='bleep', build_only=True)
- ])
- adapter = map.bind('localhost', '/')
- url = adapter.build('bleep', long_args)
- url += '/'
+ long_args = dict(("v%d" % x, x) for x in range(10000))
+ map = r.Map(
+ [
+ r.Rule(
+ "".join("/<%s>" % k for k in long_args.keys()),
+ endpoint="bleep",
+ build_only=True,
+ )
+ ]
+ )
+ adapter = map.bind("localhost", "/")
+ url = adapter.build("bleep", long_args)
+ url += "/"
for v in long_args.values():
- assert '/%d' % v in url
+ assert "/%d" % v in url
def test_defaults():
- map = r.Map([
- r.Rule('/foo/', defaults={'page': 1}, endpoint='foo'),
- r.Rule('/foo/<int:page>', endpoint='foo')
- ])
- adapter = map.bind('example.org', '/')
+ map = r.Map(
+ [
+ r.Rule("/foo/", defaults={"page": 1}, endpoint="foo"),
+ r.Rule("/foo/<int:page>", endpoint="foo"),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
- assert adapter.match('/foo/') == ('foo', {'page': 1})
- pytest.raises(r.RequestRedirect, lambda: adapter.match('/foo/1'))
- assert adapter.match('/foo/2') == ('foo', {'page': 2})
- assert adapter.build('foo', {}) == '/foo/'
- assert adapter.build('foo', {'page': 1}) == '/foo/'
- assert adapter.build('foo', {'page': 2}) == '/foo/2'
+ assert adapter.match("/foo/") == ("foo", {"page": 1})
+ pytest.raises(r.RequestRedirect, lambda: adapter.match("/foo/1"))
+ assert adapter.match("/foo/2") == ("foo", {"page": 2})
+ assert adapter.build("foo", {}) == "/foo/"
+ assert adapter.build("foo", {"page": 1}) == "/foo/"
+ assert adapter.build("foo", {"page": 2}) == "/foo/2"
def test_negative():
- map = r.Map([
- r.Rule('/foos/<int(signed=True):page>', endpoint='foos'),
- r.Rule('/bars/<float(signed=True):page>', endpoint='bars'),
- r.Rule('/foo/<int:page>', endpoint='foo'),
- r.Rule('/bar/<float:page>', endpoint='bar')
- ])
- adapter = map.bind('example.org', '/')
-
- assert adapter.match('/foos/-2') == ('foos', {'page': -2})
- assert adapter.match('/foos/-50') == ('foos', {'page': -50})
- assert adapter.match('/bars/-2.0') == ('bars', {'page': -2.0})
- assert adapter.match('/bars/-0.185') == ('bars', {'page': -0.185})
+ map = r.Map(
+ [
+ r.Rule("/foos/<int(signed=True):page>", endpoint="foos"),
+ r.Rule("/bars/<float(signed=True):page>", endpoint="bars"),
+ r.Rule("/foo/<int:page>", endpoint="foo"),
+ r.Rule("/bar/<float:page>", endpoint="bar"),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
+
+ assert adapter.match("/foos/-2") == ("foos", {"page": -2})
+ assert adapter.match("/foos/-50") == ("foos", {"page": -50})
+ assert adapter.match("/bars/-2.0") == ("bars", {"page": -2.0})
+ assert adapter.match("/bars/-0.185") == ("bars", {"page": -0.185})
# Make sure signed values are rejected in unsigned mode
- pytest.raises(r.NotFound, lambda: adapter.match('/foo/-2'))
- pytest.raises(r.NotFound, lambda: adapter.match('/foo/-50'))
- pytest.raises(r.NotFound, lambda: adapter.match('/bar/-0.185'))
- pytest.raises(r.NotFound, lambda: adapter.match('/bar/-2.0'))
+ pytest.raises(r.NotFound, lambda: adapter.match("/foo/-2"))
+ pytest.raises(r.NotFound, lambda: adapter.match("/foo/-50"))
+ pytest.raises(r.NotFound, lambda: adapter.match("/bar/-0.185"))
+ pytest.raises(r.NotFound, lambda: adapter.match("/bar/-2.0"))
def test_greedy():
- map = r.Map([
- r.Rule('/foo', endpoint='foo'),
- r.Rule('/<path:bar>', endpoint='bar'),
- r.Rule('/<path:bar>/<path:blub>', endpoint='bar')
- ])
- adapter = map.bind('example.org', '/')
+ map = r.Map(
+ [
+ r.Rule("/foo", endpoint="foo"),
+ r.Rule("/<path:bar>", endpoint="bar"),
+ r.Rule("/<path:bar>/<path:blub>", endpoint="bar"),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
- assert adapter.match('/foo') == ('foo', {})
- assert adapter.match('/blub') == ('bar', {'bar': 'blub'})
- assert adapter.match('/he/he') == ('bar', {'bar': 'he', 'blub': 'he'})
+ assert adapter.match("/foo") == ("foo", {})
+ assert adapter.match("/blub") == ("bar", {"bar": "blub"})
+ assert adapter.match("/he/he") == ("bar", {"bar": "he", "blub": "he"})
- assert adapter.build('foo', {}) == '/foo'
- assert adapter.build('bar', {'bar': 'blub'}) == '/blub'
- assert adapter.build('bar', {'bar': 'blub', 'blub': 'bar'}) == '/blub/bar'
+ assert adapter.build("foo", {}) == "/foo"
+ assert adapter.build("bar", {"bar": "blub"}) == "/blub"
+ assert adapter.build("bar", {"bar": "blub", "blub": "bar"}) == "/blub/bar"
def test_path():
- map = r.Map([
- r.Rule('/', defaults={'name': 'FrontPage'}, endpoint='page'),
- r.Rule('/Special', endpoint='special'),
- r.Rule('/<int:year>', endpoint='year'),
- r.Rule('/<path:name>:foo', endpoint='foopage'),
- r.Rule('/<path:name>:<path:name2>', endpoint='twopage'),
- r.Rule('/<path:name>', endpoint='page'),
- r.Rule('/<path:name>/edit', endpoint='editpage'),
- r.Rule('/<path:name>/silly/<path:name2>', endpoint='sillypage'),
- r.Rule('/<path:name>/silly/<path:name2>/edit', endpoint='editsillypage'),
- r.Rule('/Talk:<path:name>', endpoint='talk'),
- r.Rule('/User:<username>', endpoint='user'),
- r.Rule('/User:<username>/<path:name>', endpoint='userpage'),
- r.Rule('/User:<username>/comment/<int:id>-<int:replyId>', endpoint='usercomment'),
- r.Rule('/Files/<path:file>', endpoint='files'),
- r.Rule('/<admin>/<manage>/<things>', endpoint='admin'),
- ])
- adapter = map.bind('example.org', '/')
-
- assert adapter.match('/') == ('page', {'name': 'FrontPage'})
- pytest.raises(r.RequestRedirect, lambda: adapter.match('/FrontPage'))
- assert adapter.match('/Special') == ('special', {})
- assert adapter.match('/2007') == ('year', {'year': 2007})
- assert adapter.match('/Some:foo') == ('foopage', {'name': 'Some'})
- assert adapter.match('/Some:bar') == ('twopage', {'name': 'Some', 'name2': 'bar'})
- assert adapter.match('/Some/Page') == ('page', {'name': 'Some/Page'})
- assert adapter.match('/Some/Page/edit') == ('editpage', {'name': 'Some/Page'})
- assert adapter.match('/Foo/silly/bar') == ('sillypage', {'name': 'Foo', 'name2': 'bar'})
- assert adapter.match(
- '/Foo/silly/bar/edit') == ('editsillypage', {'name': 'Foo', 'name2': 'bar'})
- assert adapter.match('/Talk:Foo/Bar') == ('talk', {'name': 'Foo/Bar'})
- assert adapter.match('/User:thomas') == ('user', {'username': 'thomas'})
- assert adapter.match('/User:thomas/projects/werkzeug') == \
- ('userpage', {'username': 'thomas', 'name': 'projects/werkzeug'})
- assert adapter.match('/User:thomas/comment/123-456') == \
- ('usercomment', {'username': 'thomas', 'id': 123, 'replyId': 456})
- assert adapter.match('/Files/downloads/werkzeug/0.2.zip') == \
- ('files', {'file': 'downloads/werkzeug/0.2.zip'})
- assert adapter.match('/Jerry/eats/cheese') == \
- ('admin', {'admin': 'Jerry', 'manage': 'eats', 'things': 'cheese'})
+ map = r.Map(
+ [
+ r.Rule("/", defaults={"name": "FrontPage"}, endpoint="page"),
+ r.Rule("/Special", endpoint="special"),
+ r.Rule("/<int:year>", endpoint="year"),
+ r.Rule("/<path:name>:foo", endpoint="foopage"),
+ r.Rule("/<path:name>:<path:name2>", endpoint="twopage"),
+ r.Rule("/<path:name>", endpoint="page"),
+ r.Rule("/<path:name>/edit", endpoint="editpage"),
+ r.Rule("/<path:name>/silly/<path:name2>", endpoint="sillypage"),
+ r.Rule("/<path:name>/silly/<path:name2>/edit", endpoint="editsillypage"),
+ r.Rule("/Talk:<path:name>", endpoint="talk"),
+ r.Rule("/User:<username>", endpoint="user"),
+ r.Rule("/User:<username>/<path:name>", endpoint="userpage"),
+ r.Rule(
+ "/User:<username>/comment/<int:id>-<int:replyId>",
+ endpoint="usercomment",
+ ),
+ r.Rule("/Files/<path:file>", endpoint="files"),
+ r.Rule("/<admin>/<manage>/<things>", endpoint="admin"),
+ ]
+ )
+ adapter = map.bind("example.org", "/")
+
+ assert adapter.match("/") == ("page", {"name": "FrontPage"})
+ pytest.raises(r.RequestRedirect, lambda: adapter.match("/FrontPage"))
+ assert adapter.match("/Special") == ("special", {})
+ assert adapter.match("/2007") == ("year", {"year": 2007})
+ assert adapter.match("/Some:foo") == ("foopage", {"name": "Some"})
+ assert adapter.match("/Some:bar") == ("twopage", {"name": "Some", "name2": "bar"})
+ assert adapter.match("/Some/Page") == ("page", {"name": "Some/Page"})
+ assert adapter.match("/Some/Page/edit") == ("editpage", {"name": "Some/Page"})
+ assert adapter.match("/Foo/silly/bar") == (
+ "sillypage",
+ {"name": "Foo", "name2": "bar"},
+ )
+ assert adapter.match("/Foo/silly/bar/edit") == (
+ "editsillypage",
+ {"name": "Foo", "name2": "bar"},
+ )
+ assert adapter.match("/Talk:Foo/Bar") == ("talk", {"name": "Foo/Bar"})
+ assert adapter.match("/User:thomas") == ("user", {"username": "thomas"})
+ assert adapter.match("/User:thomas/projects/werkzeug") == (
+ "userpage",
+ {"username": "thomas", "name": "projects/werkzeug"},
+ )
+ assert adapter.match("/User:thomas/comment/123-456") == (
+ "usercomment",
+ {"username": "thomas", "id": 123, "replyId": 456},
+ )
+ assert adapter.match("/Files/downloads/werkzeug/0.2.zip") == (
+ "files",
+ {"file": "downloads/werkzeug/0.2.zip"},
+ )
+ assert adapter.match("/Jerry/eats/cheese") == (
+ "admin",
+ {"admin": "Jerry", "manage": "eats", "things": "cheese"},
+ )
def test_dispatch():
- env = create_environ('/')
- map = r.Map([
- r.Rule('/', endpoint='root'),
- r.Rule('/foo/', endpoint='foo')
- ])
+ env = create_environ("/")
+ map = r.Map([r.Rule("/", endpoint="root"), r.Rule("/foo/", endpoint="foo")])
adapter = map.bind_to_environ(env)
raise_this = None
@@ -296,602 +323,625 @@ def test_dispatch():
if raise_this is not None:
raise raise_this
return Response(repr((endpoint, values)))
- dispatch = lambda p, q=False: Response.force_type(
- adapter.dispatch(view_func, p, catch_http_exceptions=q),
- env
- )
- assert dispatch('/').data == b"('root', {})"
- assert dispatch('/foo').status_code == 308
+ def dispatch(path, quiet=False):
+ return Response.force_type(
+ adapter.dispatch(view_func, path, catch_http_exceptions=quiet), env
+ )
+
+ assert dispatch("/").data == b"('root', {})"
+ assert dispatch("/foo").status_code == 308
raise_this = r.NotFound()
- pytest.raises(r.NotFound, lambda: dispatch('/bar'))
- assert dispatch('/bar', True).status_code == 404
+ pytest.raises(r.NotFound, lambda: dispatch("/bar"))
+ assert dispatch("/bar", True).status_code == 404
def test_http_host_before_server_name():
env = {
- 'HTTP_HOST': 'wiki.example.com',
- 'SERVER_NAME': 'web0.example.com',
- 'SERVER_PORT': '80',
- 'SCRIPT_NAME': '',
- 'PATH_INFO': '',
- 'REQUEST_METHOD': 'GET',
- 'wsgi.url_scheme': 'http'
+ "HTTP_HOST": "wiki.example.com",
+ "SERVER_NAME": "web0.example.com",
+ "SERVER_PORT": "80",
+ "SCRIPT_NAME": "",
+ "PATH_INFO": "",
+ "REQUEST_METHOD": "GET",
+ "wsgi.url_scheme": "http",
}
- map = r.Map([r.Rule('/', endpoint='index', subdomain='wiki')])
- adapter = map.bind_to_environ(env, server_name='example.com')
- assert adapter.match('/') == ('index', {})
- assert adapter.build('index', force_external=True) == 'http://wiki.example.com/'
- assert adapter.build('index') == '/'
+ map = r.Map([r.Rule("/", endpoint="index", subdomain="wiki")])
+ adapter = map.bind_to_environ(env, server_name="example.com")
+ assert adapter.match("/") == ("index", {})
+ assert adapter.build("index", force_external=True) == "http://wiki.example.com/"
+ assert adapter.build("index") == "/"
- env['HTTP_HOST'] = 'admin.example.com'
- adapter = map.bind_to_environ(env, server_name='example.com')
- assert adapter.build('index') == 'http://wiki.example.com/'
+ env["HTTP_HOST"] = "admin.example.com"
+ adapter = map.bind_to_environ(env, server_name="example.com")
+ assert adapter.build("index") == "http://wiki.example.com/"
def test_adapter_url_parameter_sorting():
- map = r.Map([r.Rule('/', endpoint='index')], sort_parameters=True,
- sort_key=lambda x: x[1])
- adapter = map.bind('localhost', '/')
- assert adapter.build('index', {'x': 20, 'y': 10, 'z': 30},
- force_external=True) == 'http://localhost/?y=10&x=20&z=30'
+ map = r.Map(
+ [r.Rule("/", endpoint="index")], sort_parameters=True, sort_key=lambda x: x[1]
+ )
+ adapter = map.bind("localhost", "/")
+ assert (
+ adapter.build("index", {"x": 20, "y": 10, "z": 30}, force_external=True)
+ == "http://localhost/?y=10&x=20&z=30"
+ )
def test_request_direct_charset_bug():
- map = r.Map([r.Rule(u'/öäü/')])
- adapter = map.bind('localhost', '/')
+ map = r.Map([r.Rule(u"/öäü/")])
+ adapter = map.bind("localhost", "/")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match(u'/öäü')
- assert excinfo.value.new_url == 'http://localhost/%C3%B6%C3%A4%C3%BC/'
+ adapter.match(u"/öäü")
+ assert excinfo.value.new_url == "http://localhost/%C3%B6%C3%A4%C3%BC/"
def test_request_redirect_default():
- map = r.Map([r.Rule(u'/foo', defaults={'bar': 42}),
- r.Rule(u'/foo/<int:bar>')])
- adapter = map.bind('localhost', '/')
+ map = r.Map([r.Rule(u"/foo", defaults={"bar": 42}), r.Rule(u"/foo/<int:bar>")])
+ adapter = map.bind("localhost", "/")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match(u'/foo/42')
- assert excinfo.value.new_url == 'http://localhost/foo'
+ adapter.match(u"/foo/42")
+ assert excinfo.value.new_url == "http://localhost/foo"
def test_request_redirect_default_subdomain():
- map = r.Map([r.Rule(u'/foo', defaults={'bar': 42}, subdomain='test'),
- r.Rule(u'/foo/<int:bar>', subdomain='other')])
- adapter = map.bind('localhost', '/', subdomain='other')
+ map = r.Map(
+ [
+ r.Rule(u"/foo", defaults={"bar": 42}, subdomain="test"),
+ r.Rule(u"/foo/<int:bar>", subdomain="other"),
+ ]
+ )
+ adapter = map.bind("localhost", "/", subdomain="other")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match(u'/foo/42')
- assert excinfo.value.new_url == 'http://test.localhost/foo'
+ adapter.match(u"/foo/42")
+ assert excinfo.value.new_url == "http://test.localhost/foo"
def test_adapter_match_return_rule():
- rule = r.Rule('/foo/', endpoint='foo')
+ rule = r.Rule("/foo/", endpoint="foo")
map = r.Map([rule])
- adapter = map.bind('localhost', '/')
- assert adapter.match('/foo/', return_rule=True) == (rule, {})
+ adapter = map.bind("localhost", "/")
+ assert adapter.match("/foo/", return_rule=True) == (rule, {})
def test_server_name_interpolation():
- server_name = 'example.invalid'
- map = r.Map([r.Rule('/', endpoint='index'),
- r.Rule('/', endpoint='alt', subdomain='alt')])
+ server_name = "example.invalid"
+ map = r.Map(
+ [r.Rule("/", endpoint="index"), r.Rule("/", endpoint="alt", subdomain="alt")]
+ )
- env = create_environ('/', 'http://%s/' % server_name)
+ env = create_environ("/", "http://%s/" % server_name)
adapter = map.bind_to_environ(env, server_name=server_name)
- assert adapter.match() == ('index', {})
+ assert adapter.match() == ("index", {})
- env = create_environ('/', 'http://alt.%s/' % server_name)
+ env = create_environ("/", "http://alt.%s/" % server_name)
adapter = map.bind_to_environ(env, server_name=server_name)
- assert adapter.match() == ('alt', {})
+ assert adapter.match() == ("alt", {})
- env = create_environ('/', 'http://%s/' % server_name)
- adapter = map.bind_to_environ(env, server_name='foo')
- assert adapter.subdomain == '<invalid>'
+ env = create_environ("/", "http://%s/" % server_name)
+ adapter = map.bind_to_environ(env, server_name="foo")
+ assert adapter.subdomain == "<invalid>"
def test_rule_emptying():
- rule = r.Rule('/foo', {'meh': 'muh'}, 'x', ['POST'],
- False, 'x', True, None)
+ rule = r.Rule("/foo", {"meh": "muh"}, "x", ["POST"], False, "x", True, None)
rule2 = rule.empty()
assert rule.__dict__ == rule2.__dict__
- rule.methods.add('GET')
+ rule.methods.add("GET")
assert rule.__dict__ != rule2.__dict__
- rule.methods.discard('GET')
- rule.defaults['meh'] = 'aha'
+ rule.methods.discard("GET")
+ rule.defaults["meh"] = "aha"
assert rule.__dict__ != rule2.__dict__
def test_rule_unhashable():
- rule = r.Rule('/foo', {'meh': 'muh'}, 'x', ['POST'],
- False, 'x', True, None)
+ rule = r.Rule("/foo", {"meh": "muh"}, "x", ["POST"], False, "x", True, None)
pytest.raises(TypeError, hash, rule)
def test_rule_templates():
- testcase = r.RuleTemplate([
- r.Submount(
- '/test/$app',
- [r.Rule('/foo/', endpoint='handle_foo'),
- r.Rule('/bar/', endpoint='handle_bar'),
- r.Rule('/baz/', endpoint='handle_baz')]),
- r.EndpointPrefix(
- '${app}',
- [r.Rule('/${app}-blah', endpoint='bar'),
- r.Rule('/${app}-meh', endpoint='baz')]),
- r.Subdomain(
- '$app',
- [r.Rule('/blah', endpoint='x_bar'),
- r.Rule('/meh', endpoint='x_baz')])
- ])
+ testcase = r.RuleTemplate(
+ [
+ r.Submount(
+ "/test/$app",
+ [
+ r.Rule("/foo/", endpoint="handle_foo"),
+ r.Rule("/bar/", endpoint="handle_bar"),
+ r.Rule("/baz/", endpoint="handle_baz"),
+ ],
+ ),
+ r.EndpointPrefix(
+ "${app}",
+ [
+ r.Rule("/${app}-blah", endpoint="bar"),
+ r.Rule("/${app}-meh", endpoint="baz"),
+ ],
+ ),
+ r.Subdomain(
+ "$app",
+ [r.Rule("/blah", endpoint="x_bar"), r.Rule("/meh", endpoint="x_baz")],
+ ),
+ ]
+ )
url_map = r.Map(
- [testcase(app='test1'), testcase(app='test2'), testcase(app='test3'), testcase(app='test4')
- ])
-
- out = sorted([(x.rule, x.subdomain, x.endpoint)
- for x in url_map.iter_rules()])
-
- assert out == ([
- ('/blah', 'test1', 'x_bar'),
- ('/blah', 'test2', 'x_bar'),
- ('/blah', 'test3', 'x_bar'),
- ('/blah', 'test4', 'x_bar'),
- ('/meh', 'test1', 'x_baz'),
- ('/meh', 'test2', 'x_baz'),
- ('/meh', 'test3', 'x_baz'),
- ('/meh', 'test4', 'x_baz'),
- ('/test/test1/bar/', '', 'handle_bar'),
- ('/test/test1/baz/', '', 'handle_baz'),
- ('/test/test1/foo/', '', 'handle_foo'),
- ('/test/test2/bar/', '', 'handle_bar'),
- ('/test/test2/baz/', '', 'handle_baz'),
- ('/test/test2/foo/', '', 'handle_foo'),
- ('/test/test3/bar/', '', 'handle_bar'),
- ('/test/test3/baz/', '', 'handle_baz'),
- ('/test/test3/foo/', '', 'handle_foo'),
- ('/test/test4/bar/', '', 'handle_bar'),
- ('/test/test4/baz/', '', 'handle_baz'),
- ('/test/test4/foo/', '', 'handle_foo'),
- ('/test1-blah', '', 'test1bar'),
- ('/test1-meh', '', 'test1baz'),
- ('/test2-blah', '', 'test2bar'),
- ('/test2-meh', '', 'test2baz'),
- ('/test3-blah', '', 'test3bar'),
- ('/test3-meh', '', 'test3baz'),
- ('/test4-blah', '', 'test4bar'),
- ('/test4-meh', '', 'test4baz')
- ])
+ [
+ testcase(app="test1"),
+ testcase(app="test2"),
+ testcase(app="test3"),
+ testcase(app="test4"),
+ ]
+ )
+
+ out = sorted([(x.rule, x.subdomain, x.endpoint) for x in url_map.iter_rules()])
+
+ assert out == [
+ ("/blah", "test1", "x_bar"),
+ ("/blah", "test2", "x_bar"),
+ ("/blah", "test3", "x_bar"),
+ ("/blah", "test4", "x_bar"),
+ ("/meh", "test1", "x_baz"),
+ ("/meh", "test2", "x_baz"),
+ ("/meh", "test3", "x_baz"),
+ ("/meh", "test4", "x_baz"),
+ ("/test/test1/bar/", "", "handle_bar"),
+ ("/test/test1/baz/", "", "handle_baz"),
+ ("/test/test1/foo/", "", "handle_foo"),
+ ("/test/test2/bar/", "", "handle_bar"),
+ ("/test/test2/baz/", "", "handle_baz"),
+ ("/test/test2/foo/", "", "handle_foo"),
+ ("/test/test3/bar/", "", "handle_bar"),
+ ("/test/test3/baz/", "", "handle_baz"),
+ ("/test/test3/foo/", "", "handle_foo"),
+ ("/test/test4/bar/", "", "handle_bar"),
+ ("/test/test4/baz/", "", "handle_baz"),
+ ("/test/test4/foo/", "", "handle_foo"),
+ ("/test1-blah", "", "test1bar"),
+ ("/test1-meh", "", "test1baz"),
+ ("/test2-blah", "", "test2bar"),
+ ("/test2-meh", "", "test2baz"),
+ ("/test3-blah", "", "test3bar"),
+ ("/test3-meh", "", "test3baz"),
+ ("/test4-blah", "", "test4bar"),
+ ("/test4-meh", "", "test4baz"),
+ ]
def test_non_string_parts():
- m = r.Map([
- r.Rule('/<foo>', endpoint='foo')
- ])
- a = m.bind('example.com')
- assert a.build('foo', {'foo': 42}) == '/42'
+ m = r.Map([r.Rule("/<foo>", endpoint="foo")])
+ a = m.bind("example.com")
+ assert a.build("foo", {"foo": 42}) == "/42"
def test_complex_routing_rules():
- m = r.Map([
- r.Rule('/', endpoint='index'),
- r.Rule('/<int:blub>', endpoint='an_int'),
- r.Rule('/<blub>', endpoint='a_string'),
- r.Rule('/foo/', endpoint='nested'),
- r.Rule('/foobar/', endpoint='nestedbar'),
- r.Rule('/foo/<path:testing>/', endpoint='nested_show'),
- r.Rule('/foo/<path:testing>/edit', endpoint='nested_edit'),
- r.Rule('/users/', endpoint='users', defaults={'page': 1}),
- r.Rule('/users/page/<int:page>', endpoint='users'),
- r.Rule('/foox', endpoint='foox'),
- r.Rule('/<path:bar>/<path:blub>', endpoint='barx_path_path')
- ])
- a = m.bind('example.com')
-
- assert a.match('/') == ('index', {})
- assert a.match('/42') == ('an_int', {'blub': 42})
- assert a.match('/blub') == ('a_string', {'blub': 'blub'})
- assert a.match('/foo/') == ('nested', {})
- assert a.match('/foobar/') == ('nestedbar', {})
- assert a.match('/foo/1/2/3/') == ('nested_show', {'testing': '1/2/3'})
- assert a.match('/foo/1/2/3/edit') == ('nested_edit', {'testing': '1/2/3'})
- assert a.match('/users/') == ('users', {'page': 1})
- assert a.match('/users/page/2') == ('users', {'page': 2})
- assert a.match('/foox') == ('foox', {})
- assert a.match('/1/2/3') == ('barx_path_path', {'bar': '1', 'blub': '2/3'})
-
- assert a.build('index') == '/'
- assert a.build('an_int', {'blub': 42}) == '/42'
- assert a.build('a_string', {'blub': 'test'}) == '/test'
- assert a.build('nested') == '/foo/'
- assert a.build('nestedbar') == '/foobar/'
- assert a.build('nested_show', {'testing': '1/2/3'}) == '/foo/1/2/3/'
- assert a.build('nested_edit', {'testing': '1/2/3'}) == '/foo/1/2/3/edit'
- assert a.build('users', {'page': 1}) == '/users/'
- assert a.build('users', {'page': 2}) == '/users/page/2'
- assert a.build('foox') == '/foox'
- assert a.build('barx_path_path', {'bar': '1', 'blub': '2/3'}) == '/1/2/3'
+ m = r.Map(
+ [
+ r.Rule("/", endpoint="index"),
+ r.Rule("/<int:blub>", endpoint="an_int"),
+ r.Rule("/<blub>", endpoint="a_string"),
+ r.Rule("/foo/", endpoint="nested"),
+ r.Rule("/foobar/", endpoint="nestedbar"),
+ r.Rule("/foo/<path:testing>/", endpoint="nested_show"),
+ r.Rule("/foo/<path:testing>/edit", endpoint="nested_edit"),
+ r.Rule("/users/", endpoint="users", defaults={"page": 1}),
+ r.Rule("/users/page/<int:page>", endpoint="users"),
+ r.Rule("/foox", endpoint="foox"),
+ r.Rule("/<path:bar>/<path:blub>", endpoint="barx_path_path"),
+ ]
+ )
+ a = m.bind("example.com")
+
+ assert a.match("/") == ("index", {})
+ assert a.match("/42") == ("an_int", {"blub": 42})
+ assert a.match("/blub") == ("a_string", {"blub": "blub"})
+ assert a.match("/foo/") == ("nested", {})
+ assert a.match("/foobar/") == ("nestedbar", {})
+ assert a.match("/foo/1/2/3/") == ("nested_show", {"testing": "1/2/3"})
+ assert a.match("/foo/1/2/3/edit") == ("nested_edit", {"testing": "1/2/3"})
+ assert a.match("/users/") == ("users", {"page": 1})
+ assert a.match("/users/page/2") == ("users", {"page": 2})
+ assert a.match("/foox") == ("foox", {})
+ assert a.match("/1/2/3") == ("barx_path_path", {"bar": "1", "blub": "2/3"})
+
+ assert a.build("index") == "/"
+ assert a.build("an_int", {"blub": 42}) == "/42"
+ assert a.build("a_string", {"blub": "test"}) == "/test"
+ assert a.build("nested") == "/foo/"
+ assert a.build("nestedbar") == "/foobar/"
+ assert a.build("nested_show", {"testing": "1/2/3"}) == "/foo/1/2/3/"
+ assert a.build("nested_edit", {"testing": "1/2/3"}) == "/foo/1/2/3/edit"
+ assert a.build("users", {"page": 1}) == "/users/"
+ assert a.build("users", {"page": 2}) == "/users/page/2"
+ assert a.build("foox") == "/foox"
+ assert a.build("barx_path_path", {"bar": "1", "blub": "2/3"}) == "/1/2/3"
def test_default_converters():
class MyMap(r.Map):
default_converters = r.Map.default_converters.copy()
- default_converters['foo'] = r.UnicodeConverter
+ default_converters["foo"] = r.UnicodeConverter
+
assert isinstance(r.Map.default_converters, ImmutableDict)
- m = MyMap([
- r.Rule('/a/<foo:a>', endpoint='a'),
- r.Rule('/b/<foo:b>', endpoint='b'),
- r.Rule('/c/<c>', endpoint='c')
- ], converters={'bar': r.UnicodeConverter})
- a = m.bind('example.org', '/')
- assert a.match('/a/1') == ('a', {'a': '1'})
- assert a.match('/b/2') == ('b', {'b': '2'})
- assert a.match('/c/3') == ('c', {'c': '3'})
- assert 'foo' not in r.Map.default_converters
+ m = MyMap(
+ [
+ r.Rule("/a/<foo:a>", endpoint="a"),
+ r.Rule("/b/<foo:b>", endpoint="b"),
+ r.Rule("/c/<c>", endpoint="c"),
+ ],
+ converters={"bar": r.UnicodeConverter},
+ )
+ a = m.bind("example.org", "/")
+ assert a.match("/a/1") == ("a", {"a": "1"})
+ assert a.match("/b/2") == ("b", {"b": "2"})
+ assert a.match("/c/3") == ("c", {"c": "3"})
+ assert "foo" not in r.Map.default_converters
def test_uuid_converter():
- m = r.Map([
- r.Rule('/a/<uuid:a_uuid>', endpoint='a')
- ])
- a = m.bind('example.org', '/')
- rooute, kwargs = a.match('/a/a8098c1a-f86e-11da-bd1a-00112444be1e')
- assert type(kwargs['a_uuid']) == uuid.UUID
+ m = r.Map([r.Rule("/a/<uuid:a_uuid>", endpoint="a")])
+ a = m.bind("example.org", "/")
+ rooute, kwargs = a.match("/a/a8098c1a-f86e-11da-bd1a-00112444be1e")
+ assert type(kwargs["a_uuid"]) == uuid.UUID
def test_converter_with_tuples():
- '''
+ """
Regression test for https://github.com/pallets/werkzeug/issues/709
- '''
- class TwoValueConverter(r.BaseConverter):
+ """
+ class TwoValueConverter(r.BaseConverter):
def __init__(self, *args, **kwargs):
super(TwoValueConverter, self).__init__(*args, **kwargs)
- self.regex = r'(\w\w+)/(\w\w+)'
+ self.regex = r"(\w\w+)/(\w\w+)"
def to_python(self, two_values):
- one, two = two_values.split('/')
+ one, two = two_values.split("/")
return one, two
def to_url(self, values):
return "%s/%s" % (values[0], values[1])
- map = r.Map([
- r.Rule('/<two:foo>/', endpoint='handler')
- ], converters={'two': TwoValueConverter})
- a = map.bind('example.org', '/')
- route, kwargs = a.match('/qwert/yuiop/')
- assert kwargs['foo'] == ('qwert', 'yuiop')
+ map = r.Map(
+ [r.Rule("/<two:foo>/", endpoint="handler")],
+ converters={"two": TwoValueConverter},
+ )
+ a = map.bind("example.org", "/")
+ route, kwargs = a.match("/qwert/yuiop/")
+ assert kwargs["foo"] == ("qwert", "yuiop")
def test_anyconverter():
- m = r.Map([
- r.Rule('/<any(a1, a2):a>', endpoint='no_dot'),
- r.Rule('/<any(a.1, a.2):a>', endpoint='yes_dot')
- ])
- a = m.bind('example.org', '/')
- assert a.match('/a1') == ('no_dot', {'a': 'a1'})
- assert a.match('/a2') == ('no_dot', {'a': 'a2'})
- assert a.match('/a.1') == ('yes_dot', {'a': 'a.1'})
- assert a.match('/a.2') == ('yes_dot', {'a': 'a.2'})
+ m = r.Map(
+ [
+ r.Rule("/<any(a1, a2):a>", endpoint="no_dot"),
+ r.Rule("/<any(a.1, a.2):a>", endpoint="yes_dot"),
+ ]
+ )
+ a = m.bind("example.org", "/")
+ assert a.match("/a1") == ("no_dot", {"a": "a1"})
+ assert a.match("/a2") == ("no_dot", {"a": "a2"})
+ assert a.match("/a.1") == ("yes_dot", {"a": "a.1"})
+ assert a.match("/a.2") == ("yes_dot", {"a": "a.2"})
def test_build_append_unknown():
- map = r.Map([
- r.Rule('/bar/<float:bazf>', endpoint='barf')
- ])
- adapter = map.bind('example.org', '/', subdomain='blah')
- assert adapter.build('barf', {'bazf': 0.815, 'bif': 1.0}) == \
- 'http://example.org/bar/0.815?bif=1.0'
- assert adapter.build('barf', {'bazf': 0.815, 'bif': 1.0},
- append_unknown=False) == 'http://example.org/bar/0.815'
+ map = r.Map([r.Rule("/bar/<float:bazf>", endpoint="barf")])
+ adapter = map.bind("example.org", "/", subdomain="blah")
+ assert (
+ adapter.build("barf", {"bazf": 0.815, "bif": 1.0})
+ == "http://example.org/bar/0.815?bif=1.0"
+ )
+ assert (
+ adapter.build("barf", {"bazf": 0.815, "bif": 1.0}, append_unknown=False)
+ == "http://example.org/bar/0.815"
+ )
def test_build_append_multiple():
- map = r.Map([
- r.Rule('/bar/<float:foo>', endpoint='endp')
- ])
- adapter = map.bind('example.org', '/', subdomain='subd')
- params = {'foo': 0.815, 'x': [1.0, 3.0], 'y': 2.0}
- a, b = adapter.build('endp', params).split('?')
- assert a == 'http://example.org/bar/0.815'
- assert set(b.split('&')) == set('y=2.0&x=1.0&x=3.0'.split('&'))
+ map = r.Map([r.Rule("/bar/<float:foo>", endpoint="endp")])
+ adapter = map.bind("example.org", "/", subdomain="subd")
+ params = {"foo": 0.815, "x": [1.0, 3.0], "y": 2.0}
+ a, b = adapter.build("endp", params).split("?")
+ assert a == "http://example.org/bar/0.815"
+ assert set(b.split("&")) == set("y=2.0&x=1.0&x=3.0".split("&"))
def test_build_append_multidict():
- map = r.Map([
- r.Rule('/bar/<float:foo>', endpoint='endp')
- ])
- adapter = map.bind('example.org', '/', subdomain='subd')
- params = MultiDict(
- (('foo', 0.815), ('x', 1.0), ('x', 3.0), ('y', 2.0)))
- a, b = adapter.build('endp', params).split('?')
- assert a == 'http://example.org/bar/0.815'
- assert set(b.split('&')) == set('y=2.0&x=1.0&x=3.0'.split('&'))
+ map = r.Map([r.Rule("/bar/<float:foo>", endpoint="endp")])
+ adapter = map.bind("example.org", "/", subdomain="subd")
+ params = MultiDict((("foo", 0.815), ("x", 1.0), ("x", 3.0), ("y", 2.0)))
+ a, b = adapter.build("endp", params).split("?")
+ assert a == "http://example.org/bar/0.815"
+ assert set(b.split("&")) == set("y=2.0&x=1.0&x=3.0".split("&"))
def test_build_drop_none():
- map = r.Map([
- r.Rule('/flob/<flub>', endpoint='endp')
- ])
- adapter = map.bind('', '/')
- params = {'flub': None, 'flop': None}
+ map = r.Map([r.Rule("/flob/<flub>", endpoint="endp")])
+ adapter = map.bind("", "/")
+ params = {"flub": None, "flop": None}
with pytest.raises(r.BuildError):
- x = adapter.build('endp', params)
+ x = adapter.build("endp", params)
assert not x
- params = {'flub': 'x', 'flop': None}
- url = adapter.build('endp', params)
- assert 'flop' not in url
+ params = {"flub": "x", "flop": None}
+ url = adapter.build("endp", params)
+ assert "flop" not in url
def test_method_fallback():
- map = r.Map([
- r.Rule('/', endpoint='index', methods=['GET']),
- r.Rule('/<name>', endpoint='hello_name', methods=['GET']),
- r.Rule('/select', endpoint='hello_select', methods=['POST']),
- r.Rule('/search_get', endpoint='search', methods=['GET']),
- r.Rule('/search_post', endpoint='search', methods=['POST'])
- ])
- adapter = map.bind('example.com')
- assert adapter.build('index') == '/'
- assert adapter.build('index', method='GET') == '/'
- assert adapter.build('hello_name', {'name': 'foo'}) == '/foo'
- assert adapter.build('hello_select') == '/select'
- assert adapter.build('hello_select', method='POST') == '/select'
- assert adapter.build('search') == '/search_get'
- assert adapter.build('search', method='GET') == '/search_get'
- assert adapter.build('search', method='POST') == '/search_post'
+ map = r.Map(
+ [
+ r.Rule("/", endpoint="index", methods=["GET"]),
+ r.Rule("/<name>", endpoint="hello_name", methods=["GET"]),
+ r.Rule("/select", endpoint="hello_select", methods=["POST"]),
+ r.Rule("/search_get", endpoint="search", methods=["GET"]),
+ r.Rule("/search_post", endpoint="search", methods=["POST"]),
+ ]
+ )
+ adapter = map.bind("example.com")
+ assert adapter.build("index") == "/"
+ assert adapter.build("index", method="GET") == "/"
+ assert adapter.build("hello_name", {"name": "foo"}) == "/foo"
+ assert adapter.build("hello_select") == "/select"
+ assert adapter.build("hello_select", method="POST") == "/select"
+ assert adapter.build("search") == "/search_get"
+ assert adapter.build("search", method="GET") == "/search_get"
+ assert adapter.build("search", method="POST") == "/search_post"
def test_implicit_head():
- url_map = r.Map([
- r.Rule('/get', methods=['GET'], endpoint='a'),
- r.Rule('/post', methods=['POST'], endpoint='b')
- ])
- adapter = url_map.bind('example.org')
- assert adapter.match('/get', method='HEAD') == ('a', {})
- pytest.raises(r.MethodNotAllowed, adapter.match,
- '/post', method='HEAD')
+ url_map = r.Map(
+ [
+ r.Rule("/get", methods=["GET"], endpoint="a"),
+ r.Rule("/post", methods=["POST"], endpoint="b"),
+ ]
+ )
+ adapter = url_map.bind("example.org")
+ assert adapter.match("/get", method="HEAD") == ("a", {})
+ pytest.raises(r.MethodNotAllowed, adapter.match, "/post", method="HEAD")
def test_pass_str_as_router_methods():
with pytest.raises(TypeError):
- r.Rule('/get', methods='GET')
+ r.Rule("/get", methods="GET")
def test_protocol_joining_bug():
- m = r.Map([r.Rule('/<foo>', endpoint='x')])
- a = m.bind('example.org')
- assert a.build('x', {'foo': 'x:y'}) == '/x:y'
- assert a.build('x', {'foo': 'x:y'}, force_external=True) == \
- 'http://example.org/x:y'
+ m = r.Map([r.Rule("/<foo>", endpoint="x")])
+ a = m.bind("example.org")
+ assert a.build("x", {"foo": "x:y"}) == "/x:y"
+ assert a.build("x", {"foo": "x:y"}, force_external=True) == "http://example.org/x:y"
def test_allowed_methods_querying():
- m = r.Map([r.Rule('/<foo>', methods=['GET', 'HEAD']),
- r.Rule('/foo', methods=['POST'])])
- a = m.bind('example.org')
- assert sorted(a.allowed_methods('/foo')) == ['GET', 'HEAD', 'POST']
+ m = r.Map(
+ [r.Rule("/<foo>", methods=["GET", "HEAD"]), r.Rule("/foo", methods=["POST"])]
+ )
+ a = m.bind("example.org")
+ assert sorted(a.allowed_methods("/foo")) == ["GET", "HEAD", "POST"]
def test_external_building_with_port():
- map = r.Map([
- r.Rule('/', endpoint='index'),
- ])
- adapter = map.bind('example.org:5000', '/')
- built_url = adapter.build('index', {}, force_external=True)
- assert built_url == 'http://example.org:5000/', built_url
+ map = r.Map([r.Rule("/", endpoint="index")])
+ adapter = map.bind("example.org:5000", "/")
+ built_url = adapter.build("index", {}, force_external=True)
+ assert built_url == "http://example.org:5000/", built_url
def test_external_building_with_port_bind_to_environ():
- map = r.Map([
- r.Rule('/', endpoint='index'),
- ])
+ map = r.Map([r.Rule("/", endpoint="index")])
adapter = map.bind_to_environ(
- create_environ('/', 'http://example.org:5000/'),
- server_name="example.org:5000"
+ create_environ("/", "http://example.org:5000/"), server_name="example.org:5000"
)
- built_url = adapter.build('index', {}, force_external=True)
- assert built_url == 'http://example.org:5000/', built_url
+ built_url = adapter.build("index", {}, force_external=True)
+ assert built_url == "http://example.org:5000/", built_url
def test_external_building_with_port_bind_to_environ_wrong_servername():
- map = r.Map([
- r.Rule('/', endpoint='index'),
- ])
- environ = create_environ('/', 'http://example.org:5000/')
+ map = r.Map([r.Rule("/", endpoint="index")])
+ environ = create_environ("/", "http://example.org:5000/")
adapter = map.bind_to_environ(environ, server_name="example.org")
- assert adapter.subdomain == '<invalid>'
+ assert adapter.subdomain == "<invalid>"
def test_converter_parser():
- args, kwargs = r.parse_converter_args(u'test, a=1, b=3.0')
+ args, kwargs = r.parse_converter_args(u"test, a=1, b=3.0")
- assert args == ('test',)
- assert kwargs == {'a': 1, 'b': 3.0}
+ assert args == ("test",)
+ assert kwargs == {"a": 1, "b": 3.0}
- args, kwargs = r.parse_converter_args('')
+ args, kwargs = r.parse_converter_args("")
assert not args and not kwargs
- args, kwargs = r.parse_converter_args('a, b, c,')
- assert args == ('a', 'b', 'c')
+ args, kwargs = r.parse_converter_args("a, b, c,")
+ assert args == ("a", "b", "c")
assert not kwargs
- args, kwargs = r.parse_converter_args('True, False, None')
+ args, kwargs = r.parse_converter_args("True, False, None")
assert args == (True, False, None)
args, kwargs = r.parse_converter_args('"foo", u"bar"')
- assert args == ('foo', 'bar')
+ assert args == ("foo", "bar")
def test_alias_redirects():
- m = r.Map([
- r.Rule('/', endpoint='index'),
- r.Rule('/index.html', endpoint='index', alias=True),
- r.Rule('/users/', defaults={'page': 1}, endpoint='users'),
- r.Rule('/users/index.html', defaults={'page': 1}, alias=True,
- endpoint='users'),
- r.Rule('/users/page/<int:page>', endpoint='users'),
- r.Rule('/users/page-<int:page>.html', alias=True, endpoint='users'),
- ])
- a = m.bind('example.com')
+ m = r.Map(
+ [
+ r.Rule("/", endpoint="index"),
+ r.Rule("/index.html", endpoint="index", alias=True),
+ r.Rule("/users/", defaults={"page": 1}, endpoint="users"),
+ r.Rule(
+ "/users/index.html", defaults={"page": 1}, alias=True, endpoint="users"
+ ),
+ r.Rule("/users/page/<int:page>", endpoint="users"),
+ r.Rule("/users/page-<int:page>.html", alias=True, endpoint="users"),
+ ]
+ )
+ a = m.bind("example.com")
def ensure_redirect(path, new_url, args=None):
with pytest.raises(r.RequestRedirect) as excinfo:
a.match(path, query_args=args)
- assert excinfo.value.new_url == 'http://example.com' + new_url
+ assert excinfo.value.new_url == "http://example.com" + new_url
- ensure_redirect('/index.html', '/')
- ensure_redirect('/users/index.html', '/users/')
- ensure_redirect('/users/page-2.html', '/users/page/2')
- ensure_redirect('/users/page-1.html', '/users/')
- ensure_redirect('/users/page-1.html', '/users/?foo=bar', {'foo': 'bar'})
+ ensure_redirect("/index.html", "/")
+ ensure_redirect("/users/index.html", "/users/")
+ ensure_redirect("/users/page-2.html", "/users/page/2")
+ ensure_redirect("/users/page-1.html", "/users/")
+ ensure_redirect("/users/page-1.html", "/users/?foo=bar", {"foo": "bar"})
- assert a.build('index') == '/'
- assert a.build('users', {'page': 1}) == '/users/'
- assert a.build('users', {'page': 2}) == '/users/page/2'
+ assert a.build("index") == "/"
+ assert a.build("users", {"page": 1}) == "/users/"
+ assert a.build("users", {"page": 2}) == "/users/page/2"
-@pytest.mark.parametrize('prefix', ('', '/aaa'))
+@pytest.mark.parametrize("prefix", ("", "/aaa"))
def test_double_defaults(prefix):
- m = r.Map([
- r.Rule(prefix + '/', defaults={'foo': 1, 'bar': False}, endpoint='x'),
- r.Rule(prefix + '/<int:foo>', defaults={'bar': False}, endpoint='x'),
- r.Rule(prefix + '/bar/', defaults={'foo': 1, 'bar': True}, endpoint='x'),
- r.Rule(prefix + '/bar/<int:foo>', defaults={'bar': True}, endpoint='x')
- ])
- a = m.bind('example.com')
-
- assert a.match(prefix + '/') == ('x', {'foo': 1, 'bar': False})
- assert a.match(prefix + '/2') == ('x', {'foo': 2, 'bar': False})
- assert a.match(prefix + '/bar/') == ('x', {'foo': 1, 'bar': True})
- assert a.match(prefix + '/bar/2') == ('x', {'foo': 2, 'bar': True})
-
- assert a.build('x', {'foo': 1, 'bar': False}) == prefix + '/'
- assert a.build('x', {'foo': 2, 'bar': False}) == prefix + '/2'
- assert a.build('x', {'bar': False}) == prefix + '/'
- assert a.build('x', {'foo': 1, 'bar': True}) == prefix + '/bar/'
- assert a.build('x', {'foo': 2, 'bar': True}) == prefix + '/bar/2'
- assert a.build('x', {'bar': True}) == prefix + '/bar/'
+ m = r.Map(
+ [
+ r.Rule(prefix + "/", defaults={"foo": 1, "bar": False}, endpoint="x"),
+ r.Rule(prefix + "/<int:foo>", defaults={"bar": False}, endpoint="x"),
+ r.Rule(prefix + "/bar/", defaults={"foo": 1, "bar": True}, endpoint="x"),
+ r.Rule(prefix + "/bar/<int:foo>", defaults={"bar": True}, endpoint="x"),
+ ]
+ )
+ a = m.bind("example.com")
+
+ assert a.match(prefix + "/") == ("x", {"foo": 1, "bar": False})
+ assert a.match(prefix + "/2") == ("x", {"foo": 2, "bar": False})
+ assert a.match(prefix + "/bar/") == ("x", {"foo": 1, "bar": True})
+ assert a.match(prefix + "/bar/2") == ("x", {"foo": 2, "bar": True})
+
+ assert a.build("x", {"foo": 1, "bar": False}) == prefix + "/"
+ assert a.build("x", {"foo": 2, "bar": False}) == prefix + "/2"
+ assert a.build("x", {"bar": False}) == prefix + "/"
+ assert a.build("x", {"foo": 1, "bar": True}) == prefix + "/bar/"
+ assert a.build("x", {"foo": 2, "bar": True}) == prefix + "/bar/2"
+ assert a.build("x", {"bar": True}) == prefix + "/bar/"
def test_host_matching():
- m = r.Map([
- r.Rule('/', endpoint='index', host='www.<domain>'),
- r.Rule('/', endpoint='files', host='files.<domain>'),
- r.Rule('/foo/', defaults={'page': 1}, host='www.<domain>', endpoint='x'),
- r.Rule('/<int:page>', host='files.<domain>', endpoint='x')
- ], host_matching=True)
+ m = r.Map(
+ [
+ r.Rule("/", endpoint="index", host="www.<domain>"),
+ r.Rule("/", endpoint="files", host="files.<domain>"),
+ r.Rule("/foo/", defaults={"page": 1}, host="www.<domain>", endpoint="x"),
+ r.Rule("/<int:page>", host="files.<domain>", endpoint="x"),
+ ],
+ host_matching=True,
+ )
- a = m.bind('www.example.com')
- assert a.match('/') == ('index', {'domain': 'example.com'})
- assert a.match('/foo/') == ('x', {'domain': 'example.com', 'page': 1})
+ a = m.bind("www.example.com")
+ assert a.match("/") == ("index", {"domain": "example.com"})
+ assert a.match("/foo/") == ("x", {"domain": "example.com", "page": 1})
with pytest.raises(r.RequestRedirect) as excinfo:
- a.match('/foo')
- assert excinfo.value.new_url == 'http://www.example.com/foo/'
+ a.match("/foo")
+ assert excinfo.value.new_url == "http://www.example.com/foo/"
- a = m.bind('files.example.com')
- assert a.match('/') == ('files', {'domain': 'example.com'})
- assert a.match('/2') == ('x', {'domain': 'example.com', 'page': 2})
+ a = m.bind("files.example.com")
+ assert a.match("/") == ("files", {"domain": "example.com"})
+ assert a.match("/2") == ("x", {"domain": "example.com", "page": 2})
with pytest.raises(r.RequestRedirect) as excinfo:
- a.match('/1')
- assert excinfo.value.new_url == 'http://www.example.com/foo/'
+ a.match("/1")
+ assert excinfo.value.new_url == "http://www.example.com/foo/"
def test_host_matching_building():
- m = r.Map([
- r.Rule('/', endpoint='index', host='www.domain.com'),
- r.Rule('/', endpoint='foo', host='my.domain.com')
- ], host_matching=True)
+ m = r.Map(
+ [
+ r.Rule("/", endpoint="index", host="www.domain.com"),
+ r.Rule("/", endpoint="foo", host="my.domain.com"),
+ ],
+ host_matching=True,
+ )
- www = m.bind('www.domain.com')
- assert www.match('/') == ('index', {})
- assert www.build('index') == '/'
- assert www.build('foo') == 'http://my.domain.com/'
+ www = m.bind("www.domain.com")
+ assert www.match("/") == ("index", {})
+ assert www.build("index") == "/"
+ assert www.build("foo") == "http://my.domain.com/"
- my = m.bind('my.domain.com')
- assert my.match('/') == ('foo', {})
- assert my.build('foo') == '/'
- assert my.build('index') == 'http://www.domain.com/'
+ my = m.bind("my.domain.com")
+ assert my.match("/") == ("foo", {})
+ assert my.build("foo") == "/"
+ assert my.build("index") == "http://www.domain.com/"
def test_server_name_casing():
- m = r.Map([
- r.Rule('/', endpoint='index', subdomain='foo')
- ])
+ m = r.Map([r.Rule("/", endpoint="index", subdomain="foo")])
env = create_environ()
- env['SERVER_NAME'] = env['HTTP_HOST'] = 'FOO.EXAMPLE.COM'
- a = m.bind_to_environ(env, server_name='example.com')
- assert a.match('/') == ('index', {})
+ env["SERVER_NAME"] = env["HTTP_HOST"] = "FOO.EXAMPLE.COM"
+ a = m.bind_to_environ(env, server_name="example.com")
+ assert a.match("/") == ("index", {})
env = create_environ()
- env['SERVER_NAME'] = '127.0.0.1'
- env['SERVER_PORT'] = '5000'
- del env['HTTP_HOST']
- a = m.bind_to_environ(env, server_name='example.com')
+ env["SERVER_NAME"] = "127.0.0.1"
+ env["SERVER_PORT"] = "5000"
+ del env["HTTP_HOST"]
+ a = m.bind_to_environ(env, server_name="example.com")
with pytest.raises(r.NotFound):
a.match()
def test_redirect_request_exception_code():
- exc = r.RequestRedirect('http://www.google.com/')
+ exc = r.RequestRedirect("http://www.google.com/")
exc.code = 307
env = create_environ()
strict_eq(exc.get_response(env).status_code, exc.code)
def test_redirect_path_quoting():
- url_map = r.Map([
- r.Rule('/<category>', defaults={'page': 1}, endpoint='category'),
- r.Rule('/<category>/page/<int:page>', endpoint='category')
- ])
- adapter = url_map.bind('example.com')
+ url_map = r.Map(
+ [
+ r.Rule("/<category>", defaults={"page": 1}, endpoint="category"),
+ r.Rule("/<category>/page/<int:page>", endpoint="category"),
+ ]
+ )
+ adapter = url_map.bind("example.com")
with pytest.raises(r.RequestRedirect) as excinfo:
- adapter.match('/foo bar/page/1')
+ adapter.match("/foo bar/page/1")
response = excinfo.value.get_response({})
- strict_eq(response.headers['location'],
- u'http://example.com/foo%20bar')
+ strict_eq(response.headers["location"], u"http://example.com/foo%20bar")
def test_unicode_rules():
- m = r.Map([
- r.Rule(u'/войти/', endpoint='enter'),
- r.Rule(u'/foo+bar/', endpoint='foobar')
- ])
- a = m.bind(u'☃.example.com')
+ m = r.Map(
+ [r.Rule(u"/войти/", endpoint="enter"), r.Rule(u"/foo+bar/", endpoint="foobar")]
+ )
+ a = m.bind(u"☃.example.com")
with pytest.raises(r.RequestRedirect) as excinfo:
- a.match(u'/войти')
- strict_eq(excinfo.value.new_url,
- 'http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/')
+ a.match(u"/войти")
+ strict_eq(
+ excinfo.value.new_url,
+ "http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/",
+ )
- endpoint, values = a.match(u'/войти/')
- strict_eq(endpoint, 'enter')
+ endpoint, values = a.match(u"/войти/")
+ strict_eq(endpoint, "enter")
strict_eq(values, {})
with pytest.raises(r.RequestRedirect) as excinfo:
- a.match(u'/foo+bar')
- strict_eq(excinfo.value.new_url, 'http://xn--n3h.example.com/foo+bar/')
+ a.match(u"/foo+bar")
+ strict_eq(excinfo.value.new_url, "http://xn--n3h.example.com/foo+bar/")
- endpoint, values = a.match(u'/foo+bar/')
- strict_eq(endpoint, 'foobar')
+ endpoint, values = a.match(u"/foo+bar/")
+ strict_eq(endpoint, "foobar")
strict_eq(values, {})
- url = a.build('enter', {}, force_external=True)
- strict_eq(url, 'http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/')
+ url = a.build("enter", {}, force_external=True)
+ strict_eq(url, "http://xn--n3h.example.com/%D0%B2%D0%BE%D0%B9%D1%82%D0%B8/")
- url = a.build('foobar', {}, force_external=True)
- strict_eq(url, 'http://xn--n3h.example.com/foo+bar/')
+ url = a.build("foobar", {}, force_external=True)
+ strict_eq(url, "http://xn--n3h.example.com/foo+bar/")
def test_empty_path_info():
- m = r.Map([
- r.Rule("/", endpoint="index"),
- ])
+ m = r.Map([r.Rule("/", endpoint="index")])
b = m.bind("example.com", script_name="/approot")
with pytest.raises(r.RequestRedirect) as excinfo:
@@ -905,29 +955,24 @@ def test_empty_path_info():
def test_both_bind_and_match_path_info_are_none():
- m = r.Map([r.Rule(u'/', endpoint='index')])
- ma = m.bind('example.org')
- strict_eq(ma.match(), ('index', {}))
+ m = r.Map([r.Rule(u"/", endpoint="index")])
+ ma = m.bind("example.org")
+ strict_eq(ma.match(), ("index", {}))
def test_map_repr():
- m = r.Map([
- r.Rule(u'/wat', endpoint='enter'),
- r.Rule(u'/woop', endpoint='foobar')
- ])
+ m = r.Map([r.Rule(u"/wat", endpoint="enter"), r.Rule(u"/woop", endpoint="foobar")])
rv = repr(m)
- strict_eq(rv,
- "Map([<Rule '/woop' -> foobar>, <Rule '/wat' -> enter>])")
+ strict_eq(rv, "Map([<Rule '/woop' -> foobar>, <Rule '/wat' -> enter>])")
def test_empty_subclass_rules_with_custom_kwargs():
class CustomRule(r.Rule):
-
def __init__(self, string=None, custom=None, *args, **kwargs):
self.custom = custom
super(CustomRule, self).__init__(string, *args, **kwargs)
- rule1 = CustomRule(u'/foo', endpoint='bar')
+ rule1 = CustomRule(u"/foo", endpoint="bar")
try:
rule2 = rule1.empty()
assert rule1.rule == rule2.rule
@@ -936,82 +981,79 @@ def test_empty_subclass_rules_with_custom_kwargs():
def test_finding_closest_match_by_endpoint():
- m = r.Map([
- r.Rule(u'/foo/', endpoint='users.here'),
- r.Rule(u'/wat/', endpoint='admin.users'),
- r.Rule(u'/woop', endpoint='foo.users'),
- ])
- adapter = m.bind('example.com')
- assert r.BuildError('admin.user', None, None, adapter).suggested.endpoint \
- == 'admin.users'
+ m = r.Map(
+ [
+ r.Rule(u"/foo/", endpoint="users.here"),
+ r.Rule(u"/wat/", endpoint="admin.users"),
+ r.Rule(u"/woop", endpoint="foo.users"),
+ ]
+ )
+ adapter = m.bind("example.com")
+ assert (
+ r.BuildError("admin.user", None, None, adapter).suggested.endpoint
+ == "admin.users"
+ )
def test_finding_closest_match_by_values():
- rule_id = r.Rule(u'/user/id/<id>/', endpoint='users')
- rule_slug = r.Rule(u'/user/<slug>/', endpoint='users')
- rule_random = r.Rule(u'/user/emails/<email>/', endpoint='users')
+ rule_id = r.Rule(u"/user/id/<id>/", endpoint="users")
+ rule_slug = r.Rule(u"/user/<slug>/", endpoint="users")
+ rule_random = r.Rule(u"/user/emails/<email>/", endpoint="users")
m = r.Map([rule_id, rule_slug, rule_random])
- adapter = m.bind('example.com')
- assert r.BuildError('x', {'slug': ''}, None, adapter).suggested == \
- rule_slug
+ adapter = m.bind("example.com")
+ assert r.BuildError("x", {"slug": ""}, None, adapter).suggested == rule_slug
def test_finding_closest_match_by_method():
- post = r.Rule(u'/post/', endpoint='foobar', methods=['POST'])
- get = r.Rule(u'/get/', endpoint='foobar', methods=['GET'])
- put = r.Rule(u'/put/', endpoint='foobar', methods=['PUT'])
+ post = r.Rule(u"/post/", endpoint="foobar", methods=["POST"])
+ get = r.Rule(u"/get/", endpoint="foobar", methods=["GET"])
+ put = r.Rule(u"/put/", endpoint="foobar", methods=["PUT"])
m = r.Map([post, get, put])
- adapter = m.bind('example.com')
- assert r.BuildError('invalid', {}, 'POST', adapter).suggested == post
- assert r.BuildError('invalid', {}, 'GET', adapter).suggested == get
- assert r.BuildError('invalid', {}, 'PUT', adapter).suggested == put
+ adapter = m.bind("example.com")
+ assert r.BuildError("invalid", {}, "POST", adapter).suggested == post
+ assert r.BuildError("invalid", {}, "GET", adapter).suggested == get
+ assert r.BuildError("invalid", {}, "PUT", adapter).suggested == put
def test_finding_closest_match_when_none_exist():
m = r.Map([])
- assert not r.BuildError('invalid', {}, None, m.bind('test.com')).suggested
+ assert not r.BuildError("invalid", {}, None, m.bind("test.com")).suggested
def test_error_message_without_suggested_rule():
- m = r.Map([
- r.Rule(u'/foo/', endpoint='world', methods=['GET']),
- ])
- adapter = m.bind('example.com')
+ m = r.Map([r.Rule(u"/foo/", endpoint="world", methods=["GET"])])
+ adapter = m.bind("example.com")
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('urks')
- assert str(excinfo.value).startswith(
- "Could not build url for endpoint 'urks'."
- )
+ adapter.build("urks")
+ assert str(excinfo.value).startswith("Could not build url for endpoint 'urks'.")
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('world', method='POST')
+ adapter.build("world", method="POST")
assert str(excinfo.value).startswith(
"Could not build url for endpoint 'world' ('POST')."
)
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('urks', values={'user_id': 5})
+ adapter.build("urks", values={"user_id": 5})
assert str(excinfo.value).startswith(
"Could not build url for endpoint 'urks' with values ['user_id']."
)
def test_error_message_suggestion():
- m = r.Map([
- r.Rule(u'/foo/<id>/', endpoint='world', methods=['GET']),
- ])
- adapter = m.bind('example.com')
+ m = r.Map([r.Rule(u"/foo/<id>/", endpoint="world", methods=["GET"])])
+ adapter = m.bind("example.com")
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('helloworld')
+ adapter.build("helloworld")
assert "Did you mean 'world' instead?" in str(excinfo.value)
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('world')
+ adapter.build("world")
assert "Did you forget to specify values ['id']?" in str(excinfo.value)
assert "Did you mean to use methods" not in str(excinfo.value)
with pytest.raises(r.BuildError) as excinfo:
- adapter.build('world', {'id': 2}, method='POST')
+ adapter.build("world", {"id": 2}, method="POST")
assert "Did you mean to use methods ['GET', 'HEAD']?" in str(excinfo.value)
diff --git a/tests/test_security.py b/tests/test_security.py
index e30a0542..540bba92 100644
--- a/tests/test_security.py
+++ b/tests/test_security.py
@@ -10,79 +10,85 @@
"""
import os
import posixpath
+
import pytest
-from werkzeug.security import check_password_hash, generate_password_hash, \
- safe_join, pbkdf2_hex, safe_str_cmp
+from werkzeug.security import check_password_hash
+from werkzeug.security import generate_password_hash
+from werkzeug.security import pbkdf2_hex
+from werkzeug.security import safe_join
+from werkzeug.security import safe_str_cmp
def test_safe_str_cmp():
- assert safe_str_cmp('a', 'a') is True
- assert safe_str_cmp(b'a', u'a') is True
- assert safe_str_cmp('a', 'b') is False
- assert safe_str_cmp(b'aaa', 'aa') is False
- assert safe_str_cmp(b'aaa', 'bbb') is False
- assert safe_str_cmp(b'aaa', u'aaa') is True
- assert safe_str_cmp(u'aaa', u'aaa') is True
+ assert safe_str_cmp("a", "a") is True
+ assert safe_str_cmp(b"a", u"a") is True
+ assert safe_str_cmp("a", "b") is False
+ assert safe_str_cmp(b"aaa", "aa") is False
+ assert safe_str_cmp(b"aaa", "bbb") is False
+ assert safe_str_cmp(b"aaa", u"aaa") is True
+ assert safe_str_cmp(u"aaa", u"aaa") is True
def test_safe_str_cmp_no_builtin():
import werkzeug.security as sec
+
prev_value = sec._builtin_safe_str_cmp
sec._builtin_safe_str_cmp = None
- assert safe_str_cmp('a', 'ab') is False
+ assert safe_str_cmp("a", "ab") is False
- assert safe_str_cmp('str', 'str') is True
- assert safe_str_cmp('str1', 'str2') is False
+ assert safe_str_cmp("str", "str") is True
+ assert safe_str_cmp("str1", "str2") is False
sec._builtin_safe_str_cmp = prev_value
def test_password_hashing():
- hash0 = generate_password_hash('default')
- assert check_password_hash(hash0, 'default')
- assert hash0.startswith('pbkdf2:sha256:150000$')
+ hash0 = generate_password_hash("default")
+ assert check_password_hash(hash0, "default")
+ assert hash0.startswith("pbkdf2:sha256:150000$")
- hash1 = generate_password_hash('default', 'sha1')
- hash2 = generate_password_hash(u'default', method='sha1')
+ hash1 = generate_password_hash("default", "sha1")
+ hash2 = generate_password_hash(u"default", method="sha1")
assert hash1 != hash2
- assert check_password_hash(hash1, 'default')
- assert check_password_hash(hash2, 'default')
- assert hash1.startswith('sha1$')
- assert hash2.startswith('sha1$')
+ assert check_password_hash(hash1, "default")
+ assert check_password_hash(hash2, "default")
+ assert hash1.startswith("sha1$")
+ assert hash2.startswith("sha1$")
with pytest.raises(ValueError):
- check_password_hash('$made$up$', 'default')
+ check_password_hash("$made$up$", "default")
with pytest.raises(ValueError):
- generate_password_hash('default', 'sha1', salt_length=0)
+ generate_password_hash("default", "sha1", salt_length=0)
- fakehash = generate_password_hash('default', method='plain')
- assert fakehash == 'plain$$default'
- assert check_password_hash(fakehash, 'default')
+ fakehash = generate_password_hash("default", method="plain")
+ assert fakehash == "plain$$default"
+ assert check_password_hash(fakehash, "default")
- mhash = generate_password_hash(u'default', method='md5')
- assert mhash.startswith('md5$')
- assert check_password_hash(mhash, 'default')
+ mhash = generate_password_hash(u"default", method="md5")
+ assert mhash.startswith("md5$")
+ assert check_password_hash(mhash, "default")
- legacy = 'md5$$c21f969b5f03d33d43e04f8f136e7682'
- assert check_password_hash(legacy, 'default')
+ legacy = "md5$$c21f969b5f03d33d43e04f8f136e7682"
+ assert check_password_hash(legacy, "default")
- legacy = u'md5$$c21f969b5f03d33d43e04f8f136e7682'
- assert check_password_hash(legacy, 'default')
+ legacy = u"md5$$c21f969b5f03d33d43e04f8f136e7682"
+ assert check_password_hash(legacy, "default")
def test_safe_join():
- assert safe_join('foo', 'bar/baz') == posixpath.join('foo', 'bar/baz')
- assert safe_join('foo', '../bar/baz') is None
- if os.name == 'nt':
- assert safe_join('foo', 'foo\\bar') is None
+ assert safe_join("foo", "bar/baz") == posixpath.join("foo", "bar/baz")
+ assert safe_join("foo", "../bar/baz") is None
+ if os.name == "nt":
+ assert safe_join("foo", "foo\\bar") is None
def test_safe_join_os_sep():
import werkzeug.security as sec
+
prev_value = sec._os_alt_seps
- sec._os_alt_seps = '*'
- assert safe_join('foo', 'bar/baz*') is None
+ sec._os_alt_seps = "*"
+ assert safe_join("foo", "bar/baz*") is None
sec._os_alt_steps = prev_value
@@ -96,41 +102,107 @@ def test_pbkdf2():
# Assumes default keylen is 20
# check('password', 'salt', 1, None,
# '0c60c80f961f0e71f3a9b524af6012062fe037a6')
- check('password', 'salt', 1, 20, 'sha1',
- '0c60c80f961f0e71f3a9b524af6012062fe037a6')
- check('password', 'salt', 2, 20, 'sha1',
- 'ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957')
- check('password', 'salt', 4096, 20, 'sha1',
- '4b007901b765489abead49d926f721d065a429c1')
- check('passwordPASSWORDpassword', 'saltSALTsaltSALTsaltSALTsaltSALTsalt',
- 4096, 25, 'sha1', '3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038')
- check('pass\x00word', 'sa\x00lt', 4096, 16, 'sha1',
- '56fa6aa75548099dcc37d7f03425e0c3')
+ check("password", "salt", 1, 20, "sha1", "0c60c80f961f0e71f3a9b524af6012062fe037a6")
+ check("password", "salt", 2, 20, "sha1", "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957")
+ check(
+ "password", "salt", 4096, 20, "sha1", "4b007901b765489abead49d926f721d065a429c1"
+ )
+ check(
+ "passwordPASSWORDpassword",
+ "saltSALTsaltSALTsaltSALTsaltSALTsalt",
+ 4096,
+ 25,
+ "sha1",
+ "3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038",
+ )
+ check(
+ "pass\x00word", "sa\x00lt", 4096, 16, "sha1", "56fa6aa75548099dcc37d7f03425e0c3"
+ )
# PBKDF2-HMAC-SHA256 test vectors
- check('password', 'salt', 1, 32, 'sha256',
- '120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b')
- check('password', 'salt', 2, 32, 'sha256',
- 'ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43')
- check('password', 'salt', 4096, 20, 'sha256',
- 'c5e478d59288c841aa530db6845c4c8d962893a0')
+ check(
+ "password",
+ "salt",
+ 1,
+ 32,
+ "sha256",
+ "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b",
+ )
+ check(
+ "password",
+ "salt",
+ 2,
+ 32,
+ "sha256",
+ "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43",
+ )
+ check(
+ "password",
+ "salt",
+ 4096,
+ 20,
+ "sha256",
+ "c5e478d59288c841aa530db6845c4c8d962893a0",
+ )
# This one is from the RFC but it just takes for ages
# check('password', 'salt', 16777216, 20,
# 'eefe3d61cd4da4e4e9945b3d6ba2158c2634e984')
# From Crypt-PBKDF2
- check('password', 'ATHENA.MIT.EDUraeburn', 1, 16, 'sha1',
- 'cdedb5281bb2f801565a1122b2563515')
- check('password', 'ATHENA.MIT.EDUraeburn', 1, 32, 'sha1',
- 'cdedb5281bb2f801565a1122b25635150ad1f7a04bb9f3a333ecc0e2e1f70837')
- check('password', 'ATHENA.MIT.EDUraeburn', 2, 16, 'sha1',
- '01dbee7f4a9e243e988b62c73cda935d')
- check('password', 'ATHENA.MIT.EDUraeburn', 2, 32, 'sha1',
- '01dbee7f4a9e243e988b62c73cda935da05378b93244ec8f48a99e61ad799d86')
- check('password', 'ATHENA.MIT.EDUraeburn', 1200, 32, 'sha1',
- '5c08eb61fdf71e4e4ec3cf6ba1f5512ba7e52ddbc5e5142f708a31e2e62b1e13')
- check('X' * 64, 'pass phrase equals block size', 1200, 32, 'sha1',
- '139c30c0966bc32ba55fdbf212530ac9c5ec59f1a452f5cc9ad940fea0598ed1')
- check('X' * 65, 'pass phrase exceeds block size', 1200, 32, 'sha1',
- '9ccad6d468770cd51b10e6a68721be611a8b4d282601db3b36be9246915ec82a')
+ check(
+ "password",
+ "ATHENA.MIT.EDUraeburn",
+ 1,
+ 16,
+ "sha1",
+ "cdedb5281bb2f801565a1122b2563515",
+ )
+ check(
+ "password",
+ "ATHENA.MIT.EDUraeburn",
+ 1,
+ 32,
+ "sha1",
+ "cdedb5281bb2f801565a1122b25635150ad1f7a04bb9f3a333ecc0e2e1f70837",
+ )
+ check(
+ "password",
+ "ATHENA.MIT.EDUraeburn",
+ 2,
+ 16,
+ "sha1",
+ "01dbee7f4a9e243e988b62c73cda935d",
+ )
+ check(
+ "password",
+ "ATHENA.MIT.EDUraeburn",
+ 2,
+ 32,
+ "sha1",
+ "01dbee7f4a9e243e988b62c73cda935da05378b93244ec8f48a99e61ad799d86",
+ )
+ check(
+ "password",
+ "ATHENA.MIT.EDUraeburn",
+ 1200,
+ 32,
+ "sha1",
+ "5c08eb61fdf71e4e4ec3cf6ba1f5512ba7e52ddbc5e5142f708a31e2e62b1e13",
+ )
+ check(
+ "X" * 64,
+ "pass phrase equals block size",
+ 1200,
+ 32,
+ "sha1",
+ "139c30c0966bc32ba55fdbf212530ac9c5ec59f1a452f5cc9ad940fea0598ed1",
+ )
+ check(
+ "X" * 65,
+ "pass phrase exceeds block size",
+ 1200,
+ 32,
+ "sha1",
+ "9ccad6d468770cd51b10e6a68721be611a8b4d282601db3b36be9246915ec82a",
+ )
diff --git a/tests/test_serving.py b/tests/test_serving.py
index 186003f3..94a46b7f 100644
--- a/tests/test_serving.py
+++ b/tests/test_serving.py
@@ -16,6 +16,13 @@ import sys
import textwrap
import time
+import pytest
+import requests.exceptions
+
+from werkzeug import __version__ as version
+from werkzeug import _reloader
+from werkzeug import serving
+
try:
import OpenSSL
except ImportError:
@@ -27,112 +34,119 @@ except ImportError:
watchdog = None
try:
- import httplib
-except ImportError:
from http import client as httplib
-
-import requests
-import requests.exceptions
-import pytest
-
-from werkzeug import __version__ as version, serving, _reloader
+except ImportError:
+ import httplib
def test_serving(dev_server):
- server = dev_server('from werkzeug.testapp import test_app as app')
- rv = requests.get('http://%s/?foo=bar&baz=blah' % server.addr).content
- assert b'WSGI Information' in rv
- assert b'foo=bar&amp;baz=blah' in rv
- assert b'Werkzeug/' + version.encode('ascii') in rv
+ server = dev_server("from werkzeug.testapp import test_app as app")
+ rv = requests.get("http://%s/?foo=bar&baz=blah" % server.addr).content
+ assert b"WSGI Information" in rv
+ assert b"foo=bar&amp;baz=blah" in rv
+ assert b"Werkzeug/" + version.encode("ascii") in rv
def test_absolute_requests(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- assert environ['HTTP_HOST'] == 'surelynotexisting.example.com:1337'
- assert environ['PATH_INFO'] == '/index.htm'
- addr = environ['HTTP_X_WERKZEUG_ADDR']
- assert environ['SERVER_PORT'] == addr.split(':')[1]
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'YES']
- ''')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ assert environ['HTTP_HOST'] == 'surelynotexisting.example.com:1337'
+ assert environ['PATH_INFO'] == '/index.htm'
+ addr = environ['HTTP_X_WERKZEUG_ADDR']
+ assert environ['SERVER_PORT'] == addr.split(':')[1]
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'YES']
+ """
+ )
conn = httplib.HTTPConnection(server.addr)
- conn.request('GET', 'http://surelynotexisting.example.com:1337/index.htm#ignorethis',
- headers={'X-Werkzeug-Addr': server.addr})
+ conn.request(
+ "GET",
+ "http://surelynotexisting.example.com:1337/index.htm#ignorethis",
+ headers={"X-Werkzeug-Addr": server.addr},
+ )
res = conn.getresponse()
- assert res.read() == b'YES'
+ assert res.read() == b"YES"
def test_double_slash_path(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- assert 'fail' not in environ['HTTP_HOST']
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [b'YES']
- ''')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ assert 'fail' not in environ['HTTP_HOST']
+ start_response('200 OK', [('Content-Type', 'text/plain')])
+ return [b'YES']
+ """
+ )
- r = requests.get(server.url + '//fail')
- assert r.content == b'YES'
+ r = requests.get(server.url + "//fail")
+ assert r.content == b"YES"
def test_broken_app(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- 1 // 0
- ''')
-
- r = requests.get(server.url + '/?foo=bar&baz=blah')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ 1 // 0
+ """
+ )
+
+ r = requests.get(server.url + "/?foo=bar&baz=blah")
assert r.status_code == 500
- assert 'Internal Server Error' in r.text
+ assert "Internal Server Error" in r.text
-@pytest.mark.skipif(not hasattr(ssl, 'SSLContext'),
- reason='Missing PEP 466 (Python 2.7.9+) or Python 3.')
-@pytest.mark.skipif(OpenSSL is None,
- reason='OpenSSL is required for cert generation.')
+@pytest.mark.skipif(
+ not hasattr(ssl, "SSLContext"),
+ reason="Missing PEP 466 (Python 2.7.9+) or Python 3.",
+)
+@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is required for cert generation.")
def test_stdlib_ssl_contexts(dev_server, tmpdir):
- certificate, private_key = \
- serving.make_ssl_devcert(str(tmpdir.mkdir('certs')))
-
- server = dev_server('''
- def app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
-
- import ssl
- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- ctx.load_cert_chain(r"%s", r"%s")
- kwargs['ssl_context'] = ctx
- ''' % (certificate, private_key))
+ certificate, private_key = serving.make_ssl_devcert(str(tmpdir.mkdir("certs")))
+
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+
+ import ssl
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.load_cert_chain(r"%s", r"%s")
+ kwargs['ssl_context'] = ctx
+ """
+ % (certificate, private_key)
+ )
assert server.addr is not None
r = requests.get(server.url, verify=False)
- assert r.content == b'hello'
+ assert r.content == b"hello"
-@pytest.mark.skipif(OpenSSL is None, reason='OpenSSL is not installed.')
+@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is not installed.")
def test_ssl_context_adhoc(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
-
- kwargs['ssl_context'] = 'adhoc'
- ''')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+
+ kwargs['ssl_context'] = 'adhoc'
+ """
+ )
r = requests.get(server.url, verify=False)
- assert r.content == b'hello'
+ assert r.content == b"hello"
-@pytest.mark.skipif(OpenSSL is None, reason='OpenSSL is not installed.')
+@pytest.mark.skipif(OpenSSL is None, reason="OpenSSL is not installed.")
def test_make_ssl_devcert(tmpdir):
- certificate, private_key = \
- serving.make_ssl_devcert(str(tmpdir))
+ certificate, private_key = serving.make_ssl_devcert(str(tmpdir))
assert os.path.isfile(certificate)
assert os.path.isfile(private_key)
-@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.')
+@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.")
def test_reloader_broken_imports(tmpdir, dev_server):
# We explicitly assert that the server reloads on change, even though in
# this case the import could've just been retried. This is to assert
@@ -142,108 +156,126 @@ def test_reloader_broken_imports(tmpdir, dev_server):
# of directories, this only works for the watchdog reloader. The stat
# reloader is too inefficient to watch such a large amount of files.
- real_app = tmpdir.join('real_app.py')
+ real_app = tmpdir.join("real_app.py")
real_app.write("lol syntax error")
- server = dev_server('''
- trials = []
- def app(environ, start_response):
- assert not trials, 'should have reloaded'
- trials.append(1)
- import real_app
- return real_app.real_app(environ, start_response)
-
- kwargs['use_reloader'] = True
- kwargs['reloader_interval'] = 0.1
- kwargs['reloader_type'] = 'watchdog'
- ''')
+ server = dev_server(
+ """
+ trials = []
+ def app(environ, start_response):
+ assert not trials, 'should have reloaded'
+ trials.append(1)
+ import real_app
+ return real_app.real_app(environ, start_response)
+
+ kwargs['use_reloader'] = True
+ kwargs['reloader_interval'] = 0.1
+ kwargs['reloader_type'] = 'watchdog'
+ """
+ )
server.wait_for_reloader_loop()
r = requests.get(server.url)
assert r.status_code == 500
- real_app.write(textwrap.dedent('''
- def real_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
- '''))
+ real_app.write(
+ textwrap.dedent(
+ """
+ def real_app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+ """
+ )
+ )
server.wait_for_reloader()
r = requests.get(server.url)
assert r.status_code == 200
- assert r.content == b'hello'
+ assert r.content == b"hello"
-@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.')
+@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.")
def test_reloader_nested_broken_imports(tmpdir, dev_server):
- real_app = tmpdir.mkdir('real_app')
- real_app.join('__init__.py').write('from real_app.sub import real_app')
- sub = real_app.mkdir('sub').join('__init__.py')
+ real_app = tmpdir.mkdir("real_app")
+ real_app.join("__init__.py").write("from real_app.sub import real_app")
+ sub = real_app.mkdir("sub").join("__init__.py")
sub.write("lol syntax error")
- server = dev_server('''
- trials = []
- def app(environ, start_response):
- assert not trials, 'should have reloaded'
- trials.append(1)
- import real_app
- return real_app.real_app(environ, start_response)
-
- kwargs['use_reloader'] = True
- kwargs['reloader_interval'] = 0.1
- kwargs['reloader_type'] = 'watchdog'
- ''')
+ server = dev_server(
+ """
+ trials = []
+ def app(environ, start_response):
+ assert not trials, 'should have reloaded'
+ trials.append(1)
+ import real_app
+ return real_app.real_app(environ, start_response)
+
+ kwargs['use_reloader'] = True
+ kwargs['reloader_interval'] = 0.1
+ kwargs['reloader_type'] = 'watchdog'
+ """
+ )
server.wait_for_reloader_loop()
r = requests.get(server.url)
assert r.status_code == 500
- sub.write(textwrap.dedent('''
- def real_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
- '''))
+ sub.write(
+ textwrap.dedent(
+ """
+ def real_app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+ """
+ )
+ )
server.wait_for_reloader()
r = requests.get(server.url)
assert r.status_code == 200
- assert r.content == b'hello'
+ assert r.content == b"hello"
-@pytest.mark.skipif(watchdog is None, reason='Watchdog not installed.')
+@pytest.mark.skipif(watchdog is None, reason="Watchdog not installed.")
def test_reloader_reports_correct_file(tmpdir, dev_server):
- real_app = tmpdir.join('real_app.py')
- real_app.write(textwrap.dedent('''
- def real_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
- '''))
-
- server = dev_server('''
- trials = []
- def app(environ, start_response):
- assert not trials, 'should have reloaded'
- trials.append(1)
- import real_app
- return real_app.real_app(environ, start_response)
-
- kwargs['use_reloader'] = True
- kwargs['reloader_interval'] = 0.1
- kwargs['reloader_type'] = 'watchdog'
- ''')
+ real_app = tmpdir.join("real_app.py")
+ real_app.write(
+ textwrap.dedent(
+ """
+ def real_app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+ """
+ )
+ )
+
+ server = dev_server(
+ """
+ trials = []
+ def app(environ, start_response):
+ assert not trials, 'should have reloaded'
+ trials.append(1)
+ import real_app
+ return real_app.real_app(environ, start_response)
+
+ kwargs['use_reloader'] = True
+ kwargs['reloader_interval'] = 0.1
+ kwargs['reloader_type'] = 'watchdog'
+ """
+ )
server.wait_for_reloader_loop()
r = requests.get(server.url)
assert r.status_code == 200
- assert r.content == b'hello'
+ assert r.content == b"hello"
- real_app_binary = tmpdir.join('real_app.pyc')
- real_app_binary.write('anything is fine here')
+ real_app_binary = tmpdir.join("real_app.pyc")
+ real_app_binary.write("anything is fine here")
server.wait_for_reloader()
change_event = " * Detected change in '%(path)s', reloading" % {
# need to double escape Windows paths
- 'path': str(real_app_binary).replace("\\", "\\\\")
+ "path": str(real_app_binary).replace("\\", "\\\\")
}
server.logfile.seek(0)
for i in range(20):
@@ -252,17 +284,17 @@ def test_reloader_reports_correct_file(tmpdir, dev_server):
if change_event in log:
break
else:
- raise RuntimeError('Change event not detected.')
+ raise RuntimeError("Change event not detected.")
def test_windows_get_args_for_reloading(monkeypatch, tmpdir):
- test_py_exe = r'C:\Users\test\AppData\Local\Programs\Python\Python36\python.exe'
- monkeypatch.setattr(os, 'name', 'nt')
- monkeypatch.setattr(sys, 'executable', test_py_exe)
- test_exe = tmpdir.mkdir('test').join('test.exe')
- monkeypatch.setattr(sys, 'argv', [test_exe.strpath, 'run'])
+ test_py_exe = r"C:\Users\test\AppData\Local\Programs\Python\Python36\python.exe"
+ monkeypatch.setattr(os, "name", "nt")
+ monkeypatch.setattr(sys, "executable", test_py_exe)
+ test_exe = tmpdir.mkdir("test").join("test.exe")
+ monkeypatch.setattr(sys, "argv", [test_exe.strpath, "run"])
rv = _reloader._get_args_for_reloading()
- assert rv == [test_exe.strpath, 'run']
+ assert rv == [test_exe.strpath, "run"]
def test_monkeypatched_sleep(tmpdir):
@@ -273,20 +305,24 @@ def test_monkeypatched_sleep(tmpdir):
# `eventlet.monkey_patch` before importing `_reloader`, `time.sleep` is a
# python function, and subsequently calling `ReloaderLoop._sleep` fails
# with a TypeError. This test checks that _sleep is attached correctly.
- script = tmpdir.mkdir('app').join('test.py')
- script.write(textwrap.dedent('''
- import time
-
- def sleep(secs):
- pass
-
- # simulate eventlet.monkey_patch by replacing the builtin sleep
- # with a regular function before _reloader is imported
- time.sleep = sleep
-
- from werkzeug._reloader import ReloaderLoop
- ReloaderLoop()._sleep(0)
- '''))
+ script = tmpdir.mkdir("app").join("test.py")
+ script.write(
+ textwrap.dedent(
+ """
+ import time
+
+ def sleep(secs):
+ pass
+
+ # simulate eventlet.monkey_patch by replacing the builtin sleep
+ # with a regular function before _reloader is imported
+ time.sleep = sleep
+
+ from werkzeug._reloader import ReloaderLoop
+ ReloaderLoop()._sleep(0)
+ """
+ )
+ )
subprocess.check_call([sys.executable, str(script)])
@@ -295,185 +331,187 @@ def test_wrong_protocol(dev_server):
# traceback
# See https://github.com/pallets/werkzeug/pull/838
- server = dev_server('''
- def app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
- ''')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'hello']
+ """
+ )
with pytest.raises(requests.exceptions.ConnectionError):
- requests.get('https://%s/' % server.addr)
+ requests.get("https://%s/" % server.addr)
log = server.logfile.read()
- assert 'Traceback' not in log
- assert '\n127.0.0.1' in log
+ assert "Traceback" not in log
+ assert "\n127.0.0.1" in log
def test_absent_content_length_and_content_type(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- assert 'CONTENT_LENGTH' not in environ
- assert 'CONTENT_TYPE' not in environ
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'YES']
- ''')
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ assert 'CONTENT_LENGTH' not in environ
+ assert 'CONTENT_TYPE' not in environ
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'YES']
+ """
+ )
r = requests.get(server.url)
- assert r.content == b'YES'
+ assert r.content == b"YES"
def test_set_content_length_and_content_type_if_provided_by_client(dev_server):
- server = dev_server('''
- def app(environ, start_response):
- assert environ['CONTENT_LENGTH'] == '233'
- assert environ['CONTENT_TYPE'] == 'application/json'
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'YES']
- ''')
-
- r = requests.get(server.url, headers={
- 'content_length': '233',
- 'content_type': 'application/json'
- })
- assert r.content == b'YES'
+ server = dev_server(
+ """
+ def app(environ, start_response):
+ assert environ['CONTENT_LENGTH'] == '233'
+ assert environ['CONTENT_TYPE'] == 'application/json'
+ start_response('200 OK', [('Content-Type', 'text/html')])
+ return [b'YES']
+ """
+ )
+
+ r = requests.get(
+ server.url,
+ headers={"content_length": "233", "content_type": "application/json"},
+ )
+ assert r.content == b"YES"
def test_port_must_be_integer(dev_server):
def app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return [b'hello']
+ start_response("200 OK", [("Content-Type", "text/html")])
+ return [b"hello"]
with pytest.raises(TypeError) as excinfo:
- serving.run_simple(hostname='localhost', port='5001',
- application=app, use_reloader=True)
- assert 'port must be an integer' in str(excinfo.value)
+ serving.run_simple(
+ hostname="localhost", port="5001", application=app, use_reloader=True
+ )
+ assert "port must be an integer" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
- serving.run_simple(hostname='localhost', port='5001',
- application=app, use_reloader=False)
- assert 'port must be an integer' in str(excinfo.value)
+ serving.run_simple(
+ hostname="localhost", port="5001", application=app, use_reloader=False
+ )
+ assert "port must be an integer" in str(excinfo.value)
def test_chunked_encoding(dev_server):
- server = dev_server(r'''
- from werkzeug.wrappers import Request
- def app(environ, start_response):
- assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked'
- assert environ.get('wsgi.input_terminated', False)
- request = Request(environ)
- assert request.mimetype == 'multipart/form-data'
- assert request.files['file'].read() == b'This is a test\n'
- assert request.form['type'] == 'text/plain'
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [b'YES']
- ''')
-
- testfile = os.path.join(os.path.dirname(__file__), 'res', 'chunked.txt')
-
- if sys.version_info[0] == 2:
- from httplib import HTTPConnection
- else:
- from http.client import HTTPConnection
-
- conn = HTTPConnection('127.0.0.1', server.port)
+ server = dev_server(
+ r"""
+ from werkzeug.wrappers import Request
+ def app(environ, start_response):
+ assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked'
+ assert environ.get('wsgi.input_terminated', False)
+ request = Request(environ)
+ assert request.mimetype == 'multipart/form-data'
+ assert request.files['file'].read() == b'This is a test\n'
+ assert request.form['type'] == 'text/plain'
+ start_response('200 OK', [('Content-Type', 'text/plain')])
+ return [b'YES']
+ """
+ )
+
+ testfile = os.path.join(os.path.dirname(__file__), "res", "chunked.http")
+
+ conn = httplib.HTTPConnection("127.0.0.1", server.port)
conn.connect()
- conn.putrequest('POST', '/', skip_host=1, skip_accept_encoding=1)
- conn.putheader('Accept', 'text/plain')
- conn.putheader('Transfer-Encoding', 'chunked')
+ conn.putrequest("POST", "/", skip_host=1, skip_accept_encoding=1)
+ conn.putheader("Accept", "text/plain")
+ conn.putheader("Transfer-Encoding", "chunked")
conn.putheader(
- 'Content-Type',
- 'multipart/form-data; boundary='
- '--------------------------898239224156930639461866')
+ "Content-Type",
+ "multipart/form-data; boundary="
+ "--------------------------898239224156930639461866",
+ )
conn.endheaders()
- with open(testfile, 'rb') as f:
+ with open(testfile, "rb") as f:
conn.send(f.read())
res = conn.getresponse()
assert res.status == 200
- assert res.read() == b'YES'
+ assert res.read() == b"YES"
conn.close()
def test_chunked_encoding_with_content_length(dev_server):
- server = dev_server(r'''
- from werkzeug.wrappers import Request
- def app(environ, start_response):
- assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked'
- assert environ.get('wsgi.input_terminated', False)
- request = Request(environ)
- assert request.mimetype == 'multipart/form-data'
- assert request.files['file'].read() == b'This is a test\n'
- assert request.form['type'] == 'text/plain'
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [b'YES']
- ''')
-
- testfile = os.path.join(os.path.dirname(__file__), 'res', 'chunked.txt')
-
- if sys.version_info[0] == 2:
- from httplib import HTTPConnection
- else:
- from http.client import HTTPConnection
-
- conn = HTTPConnection('127.0.0.1', server.port)
+ server = dev_server(
+ r"""
+ from werkzeug.wrappers import Request
+ def app(environ, start_response):
+ assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked'
+ assert environ.get('wsgi.input_terminated', False)
+ request = Request(environ)
+ assert request.mimetype == 'multipart/form-data'
+ assert request.files['file'].read() == b'This is a test\n'
+ assert request.form['type'] == 'text/plain'
+ start_response('200 OK', [('Content-Type', 'text/plain')])
+ return [b'YES']
+ """
+ )
+
+ testfile = os.path.join(os.path.dirname(__file__), "res", "chunked.http")
+
+ conn = httplib.HTTPConnection("127.0.0.1", server.port)
conn.connect()
- conn.putrequest('POST', '/', skip_host=1, skip_accept_encoding=1)
- conn.putheader('Accept', 'text/plain')
- conn.putheader('Transfer-Encoding', 'chunked')
+ conn.putrequest("POST", "/", skip_host=1, skip_accept_encoding=1)
+ conn.putheader("Accept", "text/plain")
+ conn.putheader("Transfer-Encoding", "chunked")
# Content-Length is invalid for chunked, but some libraries might send it
- conn.putheader('Content-Length', '372')
+ conn.putheader("Content-Length", "372")
conn.putheader(
- 'Content-Type',
- 'multipart/form-data; boundary='
- '--------------------------898239224156930639461866')
+ "Content-Type",
+ "multipart/form-data; boundary="
+ "--------------------------898239224156930639461866",
+ )
conn.endheaders()
- with open(testfile, 'rb') as f:
+ with open(testfile, "rb") as f:
conn.send(f.read())
res = conn.getresponse()
assert res.status == 200
- assert res.read() == b'YES'
+ assert res.read() == b"YES"
conn.close()
def test_multiple_headers_concatenated_per_rfc_3875_section_4_1_18(dev_server):
- server = dev_server(r'''
- from werkzeug.wrappers import Response
- def app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [environ['HTTP_XYZ'].encode()]
- ''')
-
- if sys.version_info[0] == 2:
- from httplib import HTTPConnection
- else:
- from http.client import HTTPConnection
- conn = HTTPConnection('127.0.0.1', server.port)
+ server = dev_server(
+ r"""
+ from werkzeug.wrappers import Response
+ def app(environ, start_response):
+ start_response('200 OK', [('Content-Type', 'text/plain')])
+ return [environ['HTTP_XYZ'].encode()]
+ """
+ )
+
+ conn = httplib.HTTPConnection("127.0.0.1", server.port)
conn.connect()
- conn.putrequest('GET', '/')
- conn.putheader('Accept', 'text/plain')
- conn.putheader('XYZ', ' a ')
- conn.putheader('X-INGNORE-1', 'Some nonsense')
- conn.putheader('XYZ', ' b')
- conn.putheader('X-INGNORE-2', 'Some nonsense')
- conn.putheader('XYZ', 'c ')
- conn.putheader('X-INGNORE-3', 'Some nonsense')
- conn.putheader('XYZ', 'd')
+ conn.putrequest("GET", "/")
+ conn.putheader("Accept", "text/plain")
+ conn.putheader("XYZ", " a ")
+ conn.putheader("X-INGNORE-1", "Some nonsense")
+ conn.putheader("XYZ", " b")
+ conn.putheader("X-INGNORE-2", "Some nonsense")
+ conn.putheader("XYZ", "c ")
+ conn.putheader("X-INGNORE-3", "Some nonsense")
+ conn.putheader("XYZ", "d")
conn.endheaders()
- conn.send(b'')
+ conn.send(b"")
res = conn.getresponse()
assert res.status == 200
- assert res.read() == b'a ,b,c ,d'
+ assert res.read() == b"a ,b,c ,d"
conn.close()
def can_test_unix_socket():
- if not hasattr(socket, 'AF_UNIX'):
+ if not hasattr(socket, "AF_UNIX"):
return False
try:
import requests_unixsocket # noqa: F401
@@ -482,11 +520,15 @@ def can_test_unix_socket():
return True
-@pytest.mark.skipif(not can_test_unix_socket(), reason='Only works on UNIX')
+@pytest.mark.skipif(not can_test_unix_socket(), reason="Only works on UNIX")
def test_unix_socket(tmpdir, dev_server):
- socket_f = str(tmpdir.join('socket'))
- dev_server('''
- app = None
- kwargs['hostname'] = {socket!r}
- '''.format(socket='unix://' + socket_f))
+ socket_f = str(tmpdir.join("socket"))
+ dev_server(
+ """
+ app = None
+ kwargs['hostname'] = {socket!r}
+ """.format(
+ socket="unix://" + socket_f
+ )
+ )
assert os.path.exists(socket_f)
diff --git a/tests/test_test.py b/tests/test_test.py
index 82c4f3f9..e7f6306b 100644
--- a/tests/test_test.py
+++ b/tests/test_test.py
@@ -8,168 +8,176 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-
import json
-import pytest
-
import sys
-from io import BytesIO
-from werkzeug._compat import iteritems, to_bytes, implements_iterator
from functools import partial
+from io import BytesIO
-from tests import strict_eq
+import pytest
-from werkzeug.wrappers import Request, Response, BaseResponse
-from werkzeug.test import Client, EnvironBuilder, create_environ, \
- ClientRedirectError, stream_encode_multipart, run_wsgi_app
-from werkzeug.utils import redirect
+from . import strict_eq
+from werkzeug._compat import implements_iterator
+from werkzeug._compat import iteritems
+from werkzeug._compat import to_bytes
+from werkzeug.datastructures import FileStorage
+from werkzeug.datastructures import MultiDict
from werkzeug.formparser import parse_form_data
-from werkzeug.datastructures import MultiDict, FileStorage
+from werkzeug.test import Client
+from werkzeug.test import ClientRedirectError
+from werkzeug.test import create_environ
+from werkzeug.test import EnvironBuilder
+from werkzeug.test import run_wsgi_app
+from werkzeug.test import stream_encode_multipart
+from werkzeug.utils import redirect
+from werkzeug.wrappers import BaseResponse
+from werkzeug.wrappers import Request
+from werkzeug.wrappers import Response
def cookie_app(environ, start_response):
"""A WSGI application which sets a cookie, and returns as a response any
cookie which exists.
"""
- response = Response(environ.get('HTTP_COOKIE', 'No Cookie'),
- mimetype='text/plain')
- response.set_cookie('test', 'test')
+ response = Response(environ.get("HTTP_COOKIE", "No Cookie"), mimetype="text/plain")
+ response.set_cookie("test", "test")
return response(environ, start_response)
def redirect_loop_app(environ, start_response):
- response = redirect('http://localhost/some/redirect/')
+ response = redirect("http://localhost/some/redirect/")
return response(environ, start_response)
def redirect_with_get_app(environ, start_response):
req = Request(environ)
- if req.url not in ('http://localhost/',
- 'http://localhost/first/request',
- 'http://localhost/some/redirect/'):
+ if req.url not in (
+ "http://localhost/",
+ "http://localhost/first/request",
+ "http://localhost/some/redirect/",
+ ):
assert False, 'redirect_demo_app() did not expect URL "%s"' % req.url
- if '/some/redirect' not in req.url:
- response = redirect('http://localhost/some/redirect/')
+ if "/some/redirect" not in req.url:
+ response = redirect("http://localhost/some/redirect/")
else:
- response = Response('current url: %s' % req.url)
+ response = Response("current url: %s" % req.url)
return response(environ, start_response)
def external_redirect_demo_app(environ, start_response):
- response = redirect('http://example.com/')
+ response = redirect("http://example.com/")
return response(environ, start_response)
def external_subdomain_redirect_demo_app(environ, start_response):
- if 'test.example.com' in environ['HTTP_HOST']:
- response = Response('redirected successfully to subdomain')
+ if "test.example.com" in environ["HTTP_HOST"]:
+ response = Response("redirected successfully to subdomain")
else:
- response = redirect('http://test.example.com/login')
+ response = redirect("http://test.example.com/login")
return response(environ, start_response)
def multi_value_post_app(environ, start_response):
req = Request(environ)
- assert req.form['field'] == 'val1', req.form['field']
- assert req.form.getlist('field') == ['val1', 'val2'], req.form.getlist('field')
- response = Response('ok')
+ assert req.form["field"] == "val1", req.form["field"]
+ assert req.form.getlist("field") == ["val1", "val2"], req.form.getlist("field")
+ response = Response("ok")
return response(environ, start_response)
def test_cookie_forging():
c = Client(cookie_app)
- c.set_cookie('localhost', 'foo', 'bar')
+ c.set_cookie("localhost", "foo", "bar")
appiter, code, headers = c.open()
- strict_eq(list(appiter), [b'foo=bar'])
+ strict_eq(list(appiter), [b"foo=bar"])
def test_set_cookie_app():
c = Client(cookie_app)
appiter, code, headers = c.open()
- assert 'Set-Cookie' in dict(headers)
+ assert "Set-Cookie" in dict(headers)
def test_cookiejar_stores_cookie():
c = Client(cookie_app)
appiter, code, headers = c.open()
- assert 'test' in c.cookie_jar._cookies['localhost.local']['/']
+ assert "test" in c.cookie_jar._cookies["localhost.local"]["/"]
def test_no_initial_cookie():
c = Client(cookie_app)
appiter, code, headers = c.open()
- strict_eq(b''.join(appiter), b'No Cookie')
+ strict_eq(b"".join(appiter), b"No Cookie")
def test_resent_cookie():
c = Client(cookie_app)
c.open()
appiter, code, headers = c.open()
- strict_eq(b''.join(appiter), b'test=test')
+ strict_eq(b"".join(appiter), b"test=test")
def test_disable_cookies():
c = Client(cookie_app, use_cookies=False)
c.open()
appiter, code, headers = c.open()
- strict_eq(b''.join(appiter), b'No Cookie')
+ strict_eq(b"".join(appiter), b"No Cookie")
def test_cookie_for_different_path():
c = Client(cookie_app)
- c.open('/path1')
- appiter, code, headers = c.open('/path2')
- strict_eq(b''.join(appiter), b'test=test')
+ c.open("/path1")
+ appiter, code, headers = c.open("/path2")
+ strict_eq(b"".join(appiter), b"test=test")
def test_environ_builder_basics():
b = EnvironBuilder()
assert b.content_type is None
- b.method = 'POST'
+ b.method = "POST"
assert b.content_type is None
- b.form['test'] = 'normal value'
- assert b.content_type == 'application/x-www-form-urlencoded'
- b.files.add_file('test', BytesIO(b'test contents'), 'test.txt')
- assert b.files['test'].content_type == 'text/plain'
- b.form['test_int'] = 1
- assert b.content_type == 'multipart/form-data'
+ b.form["test"] = "normal value"
+ assert b.content_type == "application/x-www-form-urlencoded"
+ b.files.add_file("test", BytesIO(b"test contents"), "test.txt")
+ assert b.files["test"].content_type == "text/plain"
+ b.form["test_int"] = 1
+ assert b.content_type == "multipart/form-data"
req = b.get_request()
b.close()
- strict_eq(req.url, u'http://localhost/')
- strict_eq(req.method, 'POST')
- strict_eq(req.form['test'], u'normal value')
- assert req.files['test'].content_type == 'text/plain'
- strict_eq(req.files['test'].filename, u'test.txt')
- strict_eq(req.files['test'].read(), b'test contents')
+ strict_eq(req.url, u"http://localhost/")
+ strict_eq(req.method, "POST")
+ strict_eq(req.form["test"], u"normal value")
+ assert req.files["test"].content_type == "text/plain"
+ strict_eq(req.files["test"].filename, u"test.txt")
+ strict_eq(req.files["test"].read(), b"test contents")
def test_environ_builder_data():
- b = EnvironBuilder(data='foo')
- assert b.input_stream.getvalue() == b'foo'
- b = EnvironBuilder(data=b'foo')
- assert b.input_stream.getvalue() == b'foo'
+ b = EnvironBuilder(data="foo")
+ assert b.input_stream.getvalue() == b"foo"
+ b = EnvironBuilder(data=b"foo")
+ assert b.input_stream.getvalue() == b"foo"
- b = EnvironBuilder(data={'foo': 'bar'})
- assert b.form['foo'] == 'bar'
- b = EnvironBuilder(data={'foo': ['bar1', 'bar2']})
- assert b.form.getlist('foo') == ['bar1', 'bar2']
+ b = EnvironBuilder(data={"foo": "bar"})
+ assert b.form["foo"] == "bar"
+ b = EnvironBuilder(data={"foo": ["bar1", "bar2"]})
+ assert b.form.getlist("foo") == ["bar1", "bar2"]
def check_list_content(b, length):
- foo = b.files.getlist('foo')
+ foo = b.files.getlist("foo")
assert len(foo) == length
for obj in foo:
assert isinstance(obj, FileStorage)
- b = EnvironBuilder(data={'foo': BytesIO()})
+ b = EnvironBuilder(data={"foo": BytesIO()})
check_list_content(b, 1)
- b = EnvironBuilder(data={'foo': [BytesIO(), BytesIO()]})
+ b = EnvironBuilder(data={"foo": [BytesIO(), BytesIO()]})
check_list_content(b, 2)
- b = EnvironBuilder(data={'foo': (BytesIO(),)})
+ b = EnvironBuilder(data={"foo": (BytesIO(),)})
check_list_content(b, 1)
- b = EnvironBuilder(data={'foo': [(BytesIO(),), (BytesIO(),)]})
+ b = EnvironBuilder(data={"foo": [(BytesIO(),), (BytesIO(),)]})
check_list_content(b, 2)
@@ -188,151 +196,157 @@ def test_environ_builder_json():
def test_environ_builder_headers():
- b = EnvironBuilder(environ_base={'HTTP_USER_AGENT': 'Foo/0.1'},
- environ_overrides={'wsgi.version': (1, 1)})
- b.headers['X-Beat-My-Horse'] = 'very well sir'
+ b = EnvironBuilder(
+ environ_base={"HTTP_USER_AGENT": "Foo/0.1"},
+ environ_overrides={"wsgi.version": (1, 1)},
+ )
+ b.headers["X-Beat-My-Horse"] = "very well sir"
env = b.get_environ()
- strict_eq(env['HTTP_USER_AGENT'], 'Foo/0.1')
- strict_eq(env['HTTP_X_BEAT_MY_HORSE'], 'very well sir')
- strict_eq(env['wsgi.version'], (1, 1))
+ strict_eq(env["HTTP_USER_AGENT"], "Foo/0.1")
+ strict_eq(env["HTTP_X_BEAT_MY_HORSE"], "very well sir")
+ strict_eq(env["wsgi.version"], (1, 1))
- b.headers['User-Agent'] = 'Bar/1.0'
+ b.headers["User-Agent"] = "Bar/1.0"
env = b.get_environ()
- strict_eq(env['HTTP_USER_AGENT'], 'Bar/1.0')
+ strict_eq(env["HTTP_USER_AGENT"], "Bar/1.0")
def test_environ_builder_headers_content_type():
- b = EnvironBuilder(headers={'Content-Type': 'text/plain'})
+ b = EnvironBuilder(headers={"Content-Type": "text/plain"})
env = b.get_environ()
- assert env['CONTENT_TYPE'] == 'text/plain'
- b = EnvironBuilder(content_type='text/html',
- headers={'Content-Type': 'text/plain'})
+ assert env["CONTENT_TYPE"] == "text/plain"
+ b = EnvironBuilder(content_type="text/html", headers={"Content-Type": "text/plain"})
env = b.get_environ()
- assert env['CONTENT_TYPE'] == 'text/html'
+ assert env["CONTENT_TYPE"] == "text/html"
b = EnvironBuilder()
env = b.get_environ()
- assert 'CONTENT_TYPE' not in env
+ assert "CONTENT_TYPE" not in env
def test_environ_builder_paths():
- b = EnvironBuilder(path='/foo', base_url='http://example.com/')
- strict_eq(b.base_url, 'http://example.com/')
- strict_eq(b.path, '/foo')
- strict_eq(b.script_root, '')
- strict_eq(b.host, 'example.com')
-
- b = EnvironBuilder(path='/foo', base_url='http://example.com/bar')
- strict_eq(b.base_url, 'http://example.com/bar/')
- strict_eq(b.path, '/foo')
- strict_eq(b.script_root, '/bar')
- strict_eq(b.host, 'example.com')
-
- b.host = 'localhost'
- strict_eq(b.base_url, 'http://localhost/bar/')
- b.base_url = 'http://localhost:8080/'
- strict_eq(b.host, 'localhost:8080')
- strict_eq(b.server_name, 'localhost')
+ b = EnvironBuilder(path="/foo", base_url="http://example.com/")
+ strict_eq(b.base_url, "http://example.com/")
+ strict_eq(b.path, "/foo")
+ strict_eq(b.script_root, "")
+ strict_eq(b.host, "example.com")
+
+ b = EnvironBuilder(path="/foo", base_url="http://example.com/bar")
+ strict_eq(b.base_url, "http://example.com/bar/")
+ strict_eq(b.path, "/foo")
+ strict_eq(b.script_root, "/bar")
+ strict_eq(b.host, "example.com")
+
+ b.host = "localhost"
+ strict_eq(b.base_url, "http://localhost/bar/")
+ b.base_url = "http://localhost:8080/"
+ strict_eq(b.host, "localhost:8080")
+ strict_eq(b.server_name, "localhost")
strict_eq(b.server_port, 8080)
- b.host = 'foo.invalid'
- b.url_scheme = 'https'
- b.script_root = '/test'
+ b.host = "foo.invalid"
+ b.url_scheme = "https"
+ b.script_root = "/test"
env = b.get_environ()
- strict_eq(env['SERVER_NAME'], 'foo.invalid')
- strict_eq(env['SERVER_PORT'], '443')
- strict_eq(env['SCRIPT_NAME'], '/test')
- strict_eq(env['PATH_INFO'], '/foo')
- strict_eq(env['HTTP_HOST'], 'foo.invalid')
- strict_eq(env['wsgi.url_scheme'], 'https')
- strict_eq(b.base_url, 'https://foo.invalid/test/')
+ strict_eq(env["SERVER_NAME"], "foo.invalid")
+ strict_eq(env["SERVER_PORT"], "443")
+ strict_eq(env["SCRIPT_NAME"], "/test")
+ strict_eq(env["PATH_INFO"], "/foo")
+ strict_eq(env["HTTP_HOST"], "foo.invalid")
+ strict_eq(env["wsgi.url_scheme"], "https")
+ strict_eq(b.base_url, "https://foo.invalid/test/")
def test_environ_builder_content_type():
builder = EnvironBuilder()
assert builder.content_type is None
- builder.method = 'POST'
+ builder.method = "POST"
assert builder.content_type is None
- builder.method = 'PUT'
+ builder.method = "PUT"
assert builder.content_type is None
- builder.method = 'PATCH'
+ builder.method = "PATCH"
assert builder.content_type is None
- builder.method = 'DELETE'
+ builder.method = "DELETE"
assert builder.content_type is None
- builder.method = 'GET'
+ builder.method = "GET"
assert builder.content_type is None
- builder.form['foo'] = 'bar'
- assert builder.content_type == 'application/x-www-form-urlencoded'
- builder.files.add_file('blafasel', BytesIO(b'foo'), 'test.txt')
- assert builder.content_type == 'multipart/form-data'
+ builder.form["foo"] = "bar"
+ assert builder.content_type == "application/x-www-form-urlencoded"
+ builder.files.add_file("blafasel", BytesIO(b"foo"), "test.txt")
+ assert builder.content_type == "multipart/form-data"
req = builder.get_request()
- strict_eq(req.form['foo'], u'bar')
- strict_eq(req.files['blafasel'].read(), b'foo')
+ strict_eq(req.form["foo"], u"bar")
+ strict_eq(req.files["blafasel"].read(), b"foo")
def test_environ_builder_stream_switch():
- d = MultiDict(dict(foo=u'bar', blub=u'blah', hu=u'hum'))
+ d = MultiDict(dict(foo=u"bar", blub=u"blah", hu=u"hum"))
for use_tempfile in False, True:
stream, length, boundary = stream_encode_multipart(
- d, use_tempfile, threshold=150)
+ d, use_tempfile, threshold=150
+ )
assert isinstance(stream, BytesIO) != use_tempfile
- form = parse_form_data({'wsgi.input': stream, 'CONTENT_LENGTH': str(length),
- 'CONTENT_TYPE': 'multipart/form-data; boundary="%s"' %
- boundary})[1]
+ form = parse_form_data(
+ {
+ "wsgi.input": stream,
+ "CONTENT_LENGTH": str(length),
+ "CONTENT_TYPE": 'multipart/form-data; boundary="%s"' % boundary,
+ }
+ )[1]
strict_eq(form, d)
stream.close()
def test_environ_builder_unicode_file_mix():
for use_tempfile in False, True:
- f = FileStorage(BytesIO(u'\N{SNOWMAN}'.encode('utf-8')),
- 'snowman.txt')
- d = MultiDict(dict(f=f, s=u'\N{SNOWMAN}'))
+ f = FileStorage(BytesIO(u"\N{SNOWMAN}".encode("utf-8")), "snowman.txt")
+ d = MultiDict(dict(f=f, s=u"\N{SNOWMAN}"))
stream, length, boundary = stream_encode_multipart(
- d, use_tempfile, threshold=150)
+ d, use_tempfile, threshold=150
+ )
assert isinstance(stream, BytesIO) != use_tempfile
- _, form, files = parse_form_data({
- 'wsgi.input': stream,
- 'CONTENT_LENGTH': str(length),
- 'CONTENT_TYPE': 'multipart/form-data; boundary="%s"' %
- boundary
- })
- strict_eq(form['s'], u'\N{SNOWMAN}')
- strict_eq(files['f'].name, 'f')
- strict_eq(files['f'].filename, u'snowman.txt')
- strict_eq(files['f'].read(),
- u'\N{SNOWMAN}'.encode('utf-8'))
+ _, form, files = parse_form_data(
+ {
+ "wsgi.input": stream,
+ "CONTENT_LENGTH": str(length),
+ "CONTENT_TYPE": 'multipart/form-data; boundary="%s"' % boundary,
+ }
+ )
+ strict_eq(form["s"], u"\N{SNOWMAN}")
+ strict_eq(files["f"].name, "f")
+ strict_eq(files["f"].filename, u"snowman.txt")
+ strict_eq(files["f"].read(), u"\N{SNOWMAN}".encode("utf-8"))
stream.close()
def test_create_environ():
- env = create_environ('/foo?bar=baz', 'http://example.org/')
+ env = create_environ("/foo?bar=baz", "http://example.org/")
expected = {
- 'wsgi.multiprocess': False,
- 'wsgi.version': (1, 0),
- 'wsgi.run_once': False,
- 'wsgi.errors': sys.stderr,
- 'wsgi.multithread': False,
- 'wsgi.url_scheme': 'http',
- 'SCRIPT_NAME': '',
- 'SERVER_NAME': 'example.org',
- 'REQUEST_METHOD': 'GET',
- 'HTTP_HOST': 'example.org',
- 'PATH_INFO': '/foo',
- 'SERVER_PORT': '80',
- 'SERVER_PROTOCOL': 'HTTP/1.1',
- 'QUERY_STRING': 'bar=baz'
+ "wsgi.multiprocess": False,
+ "wsgi.version": (1, 0),
+ "wsgi.run_once": False,
+ "wsgi.errors": sys.stderr,
+ "wsgi.multithread": False,
+ "wsgi.url_scheme": "http",
+ "SCRIPT_NAME": "",
+ "SERVER_NAME": "example.org",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.org",
+ "PATH_INFO": "/foo",
+ "SERVER_PORT": "80",
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "QUERY_STRING": "bar=baz",
}
for key, value in iteritems(expected):
assert env[key] == value
- strict_eq(env['wsgi.input'].read(0), b'')
- strict_eq(create_environ('/foo', 'http://example.com/')['SCRIPT_NAME'], '')
+ strict_eq(env["wsgi.input"].read(0), b"")
+ strict_eq(create_environ("/foo", "http://example.com/")["SCRIPT_NAME"], "")
def test_create_environ_query_string_error():
with pytest.raises(ValueError):
- create_environ('/foo?bar=baz', query_string={'a': 'b'})
+ create_environ("/foo?bar=baz", query_string={"a": "b"})
def test_builder_from_environ():
@@ -341,7 +355,7 @@ def test_builder_from_environ():
base_url="https://example.com/base",
query_string={"name": "Werkzeug"},
data={"foo": "bar"},
- headers={"X-Foo": "bar"}
+ headers={"X-Foo": "bar"},
)
builder = EnvironBuilder.from_environ(environ)
try:
@@ -355,39 +369,38 @@ def test_file_closing():
closed = []
class SpecialInput(object):
-
def read(self, size):
- return ''
+ return ""
def close(self):
closed.append(self)
- create_environ(data={'foo': SpecialInput()})
+ create_environ(data={"foo": SpecialInput()})
strict_eq(len(closed), 1)
builder = EnvironBuilder()
- builder.files.add_file('blah', SpecialInput())
+ builder.files.add_file("blah", SpecialInput())
builder.close()
strict_eq(len(closed), 2)
def test_follow_redirect():
- env = create_environ('/', base_url='http://localhost')
+ env = create_environ("/", base_url="http://localhost")
c = Client(redirect_with_get_app)
appiter, code, headers = c.open(environ_overrides=env, follow_redirects=True)
- strict_eq(code, '200 OK')
- strict_eq(b''.join(appiter), b'current url: http://localhost/some/redirect/')
+ strict_eq(code, "200 OK")
+ strict_eq(b"".join(appiter), b"current url: http://localhost/some/redirect/")
# Test that the :cls:`Client` is aware of user defined response wrappers
c = Client(redirect_with_get_app, response_wrapper=BaseResponse)
- resp = c.get('/', follow_redirects=True)
+ resp = c.get("/", follow_redirects=True)
strict_eq(resp.status_code, 200)
- strict_eq(resp.data, b'current url: http://localhost/some/redirect/')
+ strict_eq(resp.data, b"current url: http://localhost/some/redirect/")
# test with URL other than '/' to make sure redirected URL's are correct
c = Client(redirect_with_get_app, response_wrapper=BaseResponse)
- resp = c.get('/first/request', follow_redirects=True)
+ resp = c.get("/first/request", follow_redirects=True)
strict_eq(resp.status_code, 200)
- strict_eq(resp.data, b'current url: http://localhost/some/redirect/')
+ strict_eq(resp.data, b"current url: http://localhost/some/redirect/")
def test_follow_local_redirect():
@@ -396,21 +409,20 @@ def test_follow_local_redirect():
def local_redirect_app(environ, start_response):
req = Request(environ)
- if '/from/location' in req.url:
- response = redirect('/to/location', Response=LocalResponse)
+ if "/from/location" in req.url:
+ response = redirect("/to/location", Response=LocalResponse)
else:
- response = Response('current path: %s' % req.path)
+ response = Response("current path: %s" % req.path)
return response(environ, start_response)
c = Client(local_redirect_app, response_wrapper=BaseResponse)
- resp = c.get('/from/location', follow_redirects=True)
+ resp = c.get("/from/location", follow_redirects=True)
strict_eq(resp.status_code, 200)
- strict_eq(resp.data, b'current path: /to/location')
+ strict_eq(resp.data, b"current path: /to/location")
@pytest.mark.parametrize(
- ("code", "keep"),
- ((302, False), (301, False), (307, True), (308, True)),
+ ("code", "keep"), ((302, False), (301, False), (307, True), (308, True))
)
def test_follow_redirect_body(code, keep):
@Request.application
@@ -430,42 +442,42 @@ def test_follow_redirect_body(code, keep):
c = Client(app, response_wrapper=BaseResponse)
response = c.post(
- "/",
- follow_redirects=True,
- data={"foo": "bar"},
- headers={"X-Foo": "bar"},
+ "/", follow_redirects=True, data={"foo": "bar"}, headers={"X-Foo": "bar"}
)
assert response.status_code == 200
assert response.data == b"current url: http://localhost/some/redirect/"
def test_follow_external_redirect():
- env = create_environ('/', base_url='http://localhost')
+ env = create_environ("/", base_url="http://localhost")
c = Client(external_redirect_demo_app)
- pytest.raises(RuntimeError, lambda:
- c.get(environ_overrides=env, follow_redirects=True))
+ pytest.raises(
+ RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True)
+ )
def test_follow_external_redirect_on_same_subdomain():
- env = create_environ('/', base_url='http://example.com')
+ env = create_environ("/", base_url="http://example.com")
c = Client(external_subdomain_redirect_demo_app, allow_subdomain_redirects=True)
c.get(environ_overrides=env, follow_redirects=True)
# check that this does not work for real external domains
- env = create_environ('/', base_url='http://localhost')
- pytest.raises(RuntimeError, lambda:
- c.get(environ_overrides=env, follow_redirects=True))
+ env = create_environ("/", base_url="http://localhost")
+ pytest.raises(
+ RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True)
+ )
# check that subdomain redirects fail if no `allow_subdomain_redirects` is applied
c = Client(external_subdomain_redirect_demo_app)
- pytest.raises(RuntimeError, lambda:
- c.get(environ_overrides=env, follow_redirects=True))
+ pytest.raises(
+ RuntimeError, lambda: c.get(environ_overrides=env, follow_redirects=True)
+ )
def test_follow_redirect_loop():
c = Client(redirect_loop_app, response_wrapper=BaseResponse)
with pytest.raises(ClientRedirectError):
- c.get('/', follow_redirects=True)
+ c.get("/", follow_redirects=True)
def test_follow_redirect_non_root_base_url():
@@ -477,7 +489,9 @@ def test_follow_redirect_non_root_base_url():
return Response(request.path)
c = Client(app, response_wrapper=Response)
- response = c.get("/redirect", base_url="http://localhost/other", follow_redirects=True)
+ response = c.get(
+ "/redirect", base_url="http://localhost/other", follow_redirects=True
+ )
assert response.data == b"/done"
@@ -508,89 +522,84 @@ def test_follow_redirect_exhaust_intermediate():
def test_path_info_script_name_unquoting():
def test_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [environ['PATH_INFO'] + '\n' + environ['SCRIPT_NAME']]
+ start_response("200 OK", [("Content-Type", "text/plain")])
+ return [environ["PATH_INFO"] + "\n" + environ["SCRIPT_NAME"]]
+
c = Client(test_app, response_wrapper=BaseResponse)
- resp = c.get('/foo%40bar')
- strict_eq(resp.data, b'/foo@bar\n')
+ resp = c.get("/foo%40bar")
+ strict_eq(resp.data, b"/foo@bar\n")
c = Client(test_app, response_wrapper=BaseResponse)
- resp = c.get('/foo%40bar', 'http://localhost/bar%40baz')
- strict_eq(resp.data, b'/foo@bar\n/bar@baz')
+ resp = c.get("/foo%40bar", "http://localhost/bar%40baz")
+ strict_eq(resp.data, b"/foo@bar\n/bar@baz")
def test_multi_value_submit():
c = Client(multi_value_post_app, response_wrapper=BaseResponse)
- data = {
- 'field': ['val1', 'val2']
- }
- resp = c.post('/', data=data)
+ data = {"field": ["val1", "val2"]}
+ resp = c.post("/", data=data)
strict_eq(resp.status_code, 200)
c = Client(multi_value_post_app, response_wrapper=BaseResponse)
- data = MultiDict({
- 'field': ['val1', 'val2']
- })
- resp = c.post('/', data=data)
+ data = MultiDict({"field": ["val1", "val2"]})
+ resp = c.post("/", data=data)
strict_eq(resp.status_code, 200)
def test_iri_support():
- b = EnvironBuilder(u'/föö-bar', base_url=u'http://☃.net/')
- strict_eq(b.path, '/f%C3%B6%C3%B6-bar')
- strict_eq(b.base_url, 'http://xn--n3h.net/')
+ b = EnvironBuilder(u"/föö-bar", base_url=u"http://☃.net/")
+ strict_eq(b.path, "/f%C3%B6%C3%B6-bar")
+ strict_eq(b.base_url, "http://xn--n3h.net/")
-@pytest.mark.parametrize('buffered', (True, False))
-@pytest.mark.parametrize('iterable', (True, False))
+@pytest.mark.parametrize("buffered", (True, False))
+@pytest.mark.parametrize("iterable", (True, False))
def test_run_wsgi_apps(buffered, iterable):
leaked_data = []
def simple_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return ['Hello World!']
+ start_response("200 OK", [("Content-Type", "text/html")])
+ return ["Hello World!"]
def yielding_app(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- yield 'Hello '
- yield 'World!'
+ start_response("200 OK", [("Content-Type", "text/html")])
+ yield "Hello "
+ yield "World!"
def late_start_response(environ, start_response):
- yield 'Hello '
- yield 'World'
- start_response('200 OK', [('Content-Type', 'text/html')])
- yield '!'
+ yield "Hello "
+ yield "World"
+ start_response("200 OK", [("Content-Type", "text/html")])
+ yield "!"
def depends_on_close(environ, start_response):
- leaked_data.append('harhar')
- start_response('200 OK', [('Content-Type', 'text/html')])
+ leaked_data.append("harhar")
+ start_response("200 OK", [("Content-Type", "text/html")])
class Rv(object):
-
def __iter__(self):
- yield 'Hello '
- yield 'World'
- yield '!'
+ yield "Hello "
+ yield "World"
+ yield "!"
def close(self):
- assert leaked_data.pop() == 'harhar'
+ assert leaked_data.pop() == "harhar"
return Rv()
- for app in (simple_app, yielding_app, late_start_response,
- depends_on_close):
+ for app in (simple_app, yielding_app, late_start_response, depends_on_close):
if iterable:
app = iterable_middleware(app)
app_iter, status, headers = run_wsgi_app(app, {}, buffered=buffered)
- strict_eq(status, '200 OK')
- strict_eq(list(headers), [('Content-Type', 'text/html')])
- strict_eq(''.join(app_iter), 'Hello World!')
+ strict_eq(status, "200 OK")
+ strict_eq(list(headers), [("Content-Type", "text/html")])
+ strict_eq("".join(app_iter), "Hello World!")
- if hasattr(app_iter, 'close'):
+ if hasattr(app_iter, "close"):
app_iter.close()
assert not leaked_data
-@pytest.mark.parametrize('buffered', (True, False))
-@pytest.mark.parametrize('iterable', (True, False))
+@pytest.mark.parametrize("buffered", (True, False))
+@pytest.mark.parametrize("iterable", (True, False))
def test_lazy_start_response_empty_response_app(buffered, iterable):
@implements_iterator
class app:
@@ -601,15 +610,15 @@ def test_lazy_start_response_empty_response_app(buffered, iterable):
return self
def __next__(self):
- self.start_response('200 OK', [('Content-Type', 'text/html')])
+ self.start_response("200 OK", [("Content-Type", "text/html")])
raise StopIteration
if iterable:
app = iterable_middleware(app)
app_iter, status, headers = run_wsgi_app(app, {}, buffered=buffered)
- strict_eq(status, '200 OK')
- strict_eq(list(headers), [('Content-Type', 'text/html')])
- strict_eq(''.join(app_iter), '')
+ strict_eq(status, "200 OK")
+ strict_eq(list(headers), [("Content-Type", "text/html")])
+ strict_eq("".join(app_iter), "")
def test_run_wsgi_app_closing_iterator():
@@ -617,7 +626,6 @@ def test_run_wsgi_app_closing_iterator():
@implements_iterator
class CloseIter(object):
-
def __init__(self):
self.iterated = False
@@ -631,39 +639,41 @@ def test_run_wsgi_app_closing_iterator():
if self.iterated:
raise StopIteration()
self.iterated = True
- return 'bar'
+ return "bar"
def bar(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/plain')])
+ start_response("200 OK", [("Content-Type", "text/plain")])
return CloseIter()
app_iter, status, headers = run_wsgi_app(bar, {})
- assert status == '200 OK'
- assert list(headers) == [('Content-Type', 'text/plain')]
- assert next(app_iter) == 'bar'
+ assert status == "200 OK"
+ assert list(headers) == [("Content-Type", "text/plain")]
+ assert next(app_iter) == "bar"
pytest.raises(StopIteration, partial(next, app_iter))
app_iter.close()
- assert run_wsgi_app(bar, {}, True)[0] == ['bar']
+ assert run_wsgi_app(bar, {}, True)[0] == ["bar"]
assert len(got_close) == 2
def iterable_middleware(app):
- '''Guarantee that the app returns an iterable'''
+ """Guarantee that the app returns an iterable"""
+
def inner(environ, start_response):
rv = app(environ, start_response)
class Iterable(object):
-
def __iter__(self):
return iter(rv)
- if hasattr(rv, 'close'):
+ if hasattr(rv, "close"):
+
def close(self):
rv.close()
return Iterable()
+
return inner
@@ -671,15 +681,17 @@ def test_multiple_cookies():
@Request.application
def test_app(request):
response = Response(repr(sorted(request.cookies.items())))
- response.set_cookie(u'test1', b'foo')
- response.set_cookie(u'test2', b'bar')
+ response.set_cookie(u"test1", b"foo")
+ response.set_cookie(u"test2", b"bar")
return response
+
client = Client(test_app, Response)
- resp = client.get('/')
- strict_eq(resp.data, b'[]')
- resp = client.get('/')
- strict_eq(resp.data,
- to_bytes(repr([('test1', u'foo'), ('test2', u'bar')]), 'ascii'))
+ resp = client.get("/")
+ strict_eq(resp.data, b"[]")
+ resp = client.get("/")
+ strict_eq(
+ resp.data, to_bytes(repr([("test1", u"foo"), ("test2", u"bar")]), "ascii")
+ )
def test_correct_open_invocation_on_redirect():
@@ -688,58 +700,59 @@ def test_correct_open_invocation_on_redirect():
def open(self, *args, **kwargs):
self.counter += 1
- env = kwargs.setdefault('environ_overrides', {})
- env['werkzeug._foo'] = self.counter
+ env = kwargs.setdefault("environ_overrides", {})
+ env["werkzeug._foo"] = self.counter
return Client.open(self, *args, **kwargs)
@Request.application
def test_app(request):
- return Response(str(request.environ['werkzeug._foo']))
+ return Response(str(request.environ["werkzeug._foo"]))
c = MyClient(test_app, response_wrapper=Response)
- strict_eq(c.get('/').data, b'1')
- strict_eq(c.get('/').data, b'2')
- strict_eq(c.get('/').data, b'3')
+ strict_eq(c.get("/").data, b"1")
+ strict_eq(c.get("/").data, b"2")
+ strict_eq(c.get("/").data, b"3")
def test_correct_encoding():
- req = Request.from_values(u'/\N{SNOWMAN}', u'http://example.com/foo')
- strict_eq(req.script_root, u'/foo')
- strict_eq(req.path, u'/\N{SNOWMAN}')
+ req = Request.from_values(u"/\N{SNOWMAN}", u"http://example.com/foo")
+ strict_eq(req.script_root, u"/foo")
+ strict_eq(req.path, u"/\N{SNOWMAN}")
def test_full_url_requests_with_args():
- base = 'http://example.com/'
+ base = "http://example.com/"
@Request.application
def test_app(request):
- return Response(request.args['x'])
+ return Response(request.args["x"])
+
client = Client(test_app, Response)
- resp = client.get('/?x=42', base)
- strict_eq(resp.data, b'42')
- resp = client.get('http://www.example.com/?x=23', base)
- strict_eq(resp.data, b'23')
+ resp = client.get("/?x=42", base)
+ strict_eq(resp.data, b"42")
+ resp = client.get("http://www.example.com/?x=23", base)
+ strict_eq(resp.data, b"23")
def test_delete_requests_with_form():
@Request.application
def test_app(request):
- return Response(request.form.get('x', None))
+ return Response(request.form.get("x", None))
client = Client(test_app, Response)
- resp = client.delete('/', data={'x': 42})
- strict_eq(resp.data, b'42')
+ resp = client.delete("/", data={"x": 42})
+ strict_eq(resp.data, b"42")
def test_post_with_file_descriptor(tmpdir):
c = Client(Response(), response_wrapper=Response)
- f = tmpdir.join('some-file.txt')
- f.write('foo')
- with open(f.strpath, mode='rt') as data:
- resp = c.post('/', data=data)
+ f = tmpdir.join("some-file.txt")
+ f.write("foo")
+ with open(f.strpath, mode="rt") as data:
+ resp = c.post("/", data=data)
strict_eq(resp.status_code, 200)
- with open(f.strpath, mode='rb') as data:
- resp = c.post('/', data=data)
+ with open(f.strpath, mode="rb") as data:
+ resp = c.post("/", data=data)
strict_eq(resp.status_code, 200)
@@ -747,13 +760,14 @@ def test_content_type():
@Request.application
def test_app(request):
return Response(request.content_type)
+
client = Client(test_app, Response)
- resp = client.get('/', data=b'testing', mimetype='text/css')
- strict_eq(resp.data, b'text/css; charset=utf-8')
+ resp = client.get("/", data=b"testing", mimetype="text/css")
+ strict_eq(resp.data, b"text/css; charset=utf-8")
- resp = client.get('/', data=b'testing', mimetype='application/octet-stream')
- strict_eq(resp.data, b'application/octet-stream')
+ resp = client.get("/", data=b"testing", mimetype="application/octet-stream")
+ strict_eq(resp.data, b"application/octet-stream")
def test_raw_request_uri():
diff --git a/tests/test_urls.py b/tests/test_urls.py
index ae3db554..71b80c33 100644
--- a/tests/test_urls.py
+++ b/tests/test_urls.py
@@ -10,394 +10,426 @@
"""
import pytest
-from tests import strict_eq
-
-from werkzeug.datastructures import OrderedMultiDict
+from . import strict_eq
from werkzeug import urls
-from werkzeug._compat import text_type, NativeStringIO, BytesIO
+from werkzeug._compat import BytesIO
+from werkzeug._compat import NativeStringIO
+from werkzeug._compat import text_type
+from werkzeug.datastructures import OrderedMultiDict
def test_parsing():
- url = urls.url_parse('http://anon:hunter2@[2001:db8:0:1]:80/a/b/c')
- assert url.netloc == 'anon:hunter2@[2001:db8:0:1]:80'
- assert url.username == 'anon'
- assert url.password == 'hunter2'
+ url = urls.url_parse("http://anon:hunter2@[2001:db8:0:1]:80/a/b/c")
+ assert url.netloc == "anon:hunter2@[2001:db8:0:1]:80"
+ assert url.username == "anon"
+ assert url.password == "hunter2"
assert url.port == 80
- assert url.ascii_host == '2001:db8:0:1'
+ assert url.ascii_host == "2001:db8:0:1"
assert url.get_file_location() == (None, None) # no file scheme
-@pytest.mark.parametrize('implicit_format', (True, False))
-@pytest.mark.parametrize('localhost', ('127.0.0.1', '::1', 'localhost'))
+@pytest.mark.parametrize("implicit_format", (True, False))
+@pytest.mark.parametrize("localhost", ("127.0.0.1", "::1", "localhost"))
def test_fileurl_parsing_windows(implicit_format, localhost, monkeypatch):
if implicit_format:
pathformat = None
- monkeypatch.setattr('os.name', 'nt')
+ monkeypatch.setattr("os.name", "nt")
else:
- pathformat = 'windows'
- monkeypatch.delattr('os.name') # just to make sure it won't get used
+ pathformat = "windows"
+ monkeypatch.delattr("os.name") # just to make sure it won't get used
- url = urls.url_parse('file:///C:/Documents and Settings/Foobar/stuff.txt')
- assert url.netloc == ''
- assert url.scheme == 'file'
- assert url.get_file_location(pathformat) == \
- (None, r'C:\Documents and Settings\Foobar\stuff.txt')
+ url = urls.url_parse("file:///C:/Documents and Settings/Foobar/stuff.txt")
+ assert url.netloc == ""
+ assert url.scheme == "file"
+ assert url.get_file_location(pathformat) == (
+ None,
+ r"C:\Documents and Settings\Foobar\stuff.txt",
+ )
- url = urls.url_parse('file://///server.tld/file.txt')
- assert url.get_file_location(pathformat) == ('server.tld', r'file.txt')
+ url = urls.url_parse("file://///server.tld/file.txt")
+ assert url.get_file_location(pathformat) == ("server.tld", r"file.txt")
- url = urls.url_parse('file://///server.tld')
- assert url.get_file_location(pathformat) == ('server.tld', '')
+ url = urls.url_parse("file://///server.tld")
+ assert url.get_file_location(pathformat) == ("server.tld", "")
- url = urls.url_parse('file://///%s' % localhost)
- assert url.get_file_location(pathformat) == (None, '')
+ url = urls.url_parse("file://///%s" % localhost)
+ assert url.get_file_location(pathformat) == (None, "")
- url = urls.url_parse('file://///%s/file.txt' % localhost)
- assert url.get_file_location(pathformat) == (None, r'file.txt')
+ url = urls.url_parse("file://///%s/file.txt" % localhost)
+ assert url.get_file_location(pathformat) == (None, r"file.txt")
def test_replace():
- url = urls.url_parse('http://de.wikipedia.org/wiki/Troll')
- strict_eq(url.replace(query='foo=bar'),
- urls.url_parse('http://de.wikipedia.org/wiki/Troll?foo=bar'))
- strict_eq(url.replace(scheme='https'),
- urls.url_parse('https://de.wikipedia.org/wiki/Troll'))
+ url = urls.url_parse("http://de.wikipedia.org/wiki/Troll")
+ strict_eq(
+ url.replace(query="foo=bar"),
+ urls.url_parse("http://de.wikipedia.org/wiki/Troll?foo=bar"),
+ )
+ strict_eq(
+ url.replace(scheme="https"),
+ urls.url_parse("https://de.wikipedia.org/wiki/Troll"),
+ )
def test_quoting():
- strict_eq(urls.url_quote(u'\xf6\xe4\xfc'), '%C3%B6%C3%A4%C3%BC')
+ strict_eq(urls.url_quote(u"\xf6\xe4\xfc"), "%C3%B6%C3%A4%C3%BC")
strict_eq(urls.url_unquote(urls.url_quote(u'#%="\xf6')), u'#%="\xf6')
- strict_eq(urls.url_quote_plus('foo bar'), 'foo+bar')
- strict_eq(urls.url_unquote_plus('foo+bar'), u'foo bar')
- strict_eq(urls.url_quote_plus('foo+bar'), 'foo%2Bbar')
- strict_eq(urls.url_unquote_plus('foo%2Bbar'), u'foo+bar')
- strict_eq(urls.url_encode({b'a': None, b'b': b'foo bar'}), 'b=foo+bar')
- strict_eq(urls.url_encode({u'a': None, u'b': u'foo bar'}), 'b=foo+bar')
- strict_eq(urls.url_fix(u'http://de.wikipedia.org/wiki/Elf (Begriffsklärung)'),
- 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)')
- strict_eq(urls.url_quote_plus(42), '42')
- strict_eq(urls.url_quote(b'\xff'), '%FF')
+ strict_eq(urls.url_quote_plus("foo bar"), "foo+bar")
+ strict_eq(urls.url_unquote_plus("foo+bar"), u"foo bar")
+ strict_eq(urls.url_quote_plus("foo+bar"), "foo%2Bbar")
+ strict_eq(urls.url_unquote_plus("foo%2Bbar"), u"foo+bar")
+ strict_eq(urls.url_encode({b"a": None, b"b": b"foo bar"}), "b=foo+bar")
+ strict_eq(urls.url_encode({u"a": None, u"b": u"foo bar"}), "b=foo+bar")
+ strict_eq(
+ urls.url_fix(u"http://de.wikipedia.org/wiki/Elf (Begriffsklärung)"),
+ "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)",
+ )
+ strict_eq(urls.url_quote_plus(42), "42")
+ strict_eq(urls.url_quote(b"\xff"), "%FF")
def test_bytes_unquoting():
- strict_eq(urls.url_unquote(urls.url_quote(
- u'#%="\xf6', charset='latin1'), charset=None), b'#%="\xf6')
+ strict_eq(
+ urls.url_unquote(urls.url_quote(u'#%="\xf6', charset="latin1"), charset=None),
+ b'#%="\xf6',
+ )
def test_url_decoding():
- x = urls.url_decode(b'foo=42&bar=23&uni=H%C3%A4nsel')
- strict_eq(x['foo'], u'42')
- strict_eq(x['bar'], u'23')
- strict_eq(x['uni'], u'Hänsel')
+ x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel")
+ strict_eq(x["foo"], u"42")
+ strict_eq(x["bar"], u"23")
+ strict_eq(x["uni"], u"Hänsel")
- x = urls.url_decode(b'foo=42;bar=23;uni=H%C3%A4nsel', separator=b';')
- strict_eq(x['foo'], u'42')
- strict_eq(x['bar'], u'23')
- strict_eq(x['uni'], u'Hänsel')
+ x = urls.url_decode(b"foo=42;bar=23;uni=H%C3%A4nsel", separator=b";")
+ strict_eq(x["foo"], u"42")
+ strict_eq(x["bar"], u"23")
+ strict_eq(x["uni"], u"Hänsel")
- x = urls.url_decode(b'%C3%9Ch=H%C3%A4nsel', decode_keys=True)
- strict_eq(x[u'Üh'], u'Hänsel')
+ x = urls.url_decode(b"%C3%9Ch=H%C3%A4nsel", decode_keys=True)
+ strict_eq(x[u"Üh"], u"Hänsel")
def test_url_bytes_decoding():
- x = urls.url_decode(b'foo=42&bar=23&uni=H%C3%A4nsel', charset=None)
- strict_eq(x[b'foo'], b'42')
- strict_eq(x[b'bar'], b'23')
- strict_eq(x[b'uni'], u'Hänsel'.encode('utf-8'))
+ x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel", charset=None)
+ strict_eq(x[b"foo"], b"42")
+ strict_eq(x[b"bar"], b"23")
+ strict_eq(x[b"uni"], u"Hänsel".encode("utf-8"))
def test_streamed_url_decoding():
- item1 = u'a' * 100000
- item2 = u'b' * 400
- string = ('a=%s&b=%s&c=%s' % (item1, item2, item2)).encode('ascii')
- gen = urls.url_decode_stream(BytesIO(string), limit=len(string),
- return_iterator=True)
- strict_eq(next(gen), ('a', item1))
- strict_eq(next(gen), ('b', item2))
- strict_eq(next(gen), ('c', item2))
+ item1 = u"a" * 100000
+ item2 = u"b" * 400
+ string = ("a=%s&b=%s&c=%s" % (item1, item2, item2)).encode("ascii")
+ gen = urls.url_decode_stream(
+ BytesIO(string), limit=len(string), return_iterator=True
+ )
+ strict_eq(next(gen), ("a", item1))
+ strict_eq(next(gen), ("b", item2))
+ strict_eq(next(gen), ("c", item2))
pytest.raises(StopIteration, lambda: next(gen))
def test_stream_decoding_string_fails():
- pytest.raises(TypeError, urls.url_decode_stream, 'testing')
+ pytest.raises(TypeError, urls.url_decode_stream, "testing")
def test_url_encoding():
- strict_eq(urls.url_encode({'foo': 'bar 45'}), 'foo=bar+45')
- d = {'foo': 1, 'bar': 23, 'blah': u'Hänsel'}
- strict_eq(urls.url_encode(d, sort=True), 'bar=23&blah=H%C3%A4nsel&foo=1')
- strict_eq(urls.url_encode(d, sort=True, separator=u';'), 'bar=23;blah=H%C3%A4nsel;foo=1')
+ strict_eq(urls.url_encode({"foo": "bar 45"}), "foo=bar+45")
+ d = {"foo": 1, "bar": 23, "blah": u"Hänsel"}
+ strict_eq(urls.url_encode(d, sort=True), "bar=23&blah=H%C3%A4nsel&foo=1")
+ strict_eq(
+ urls.url_encode(d, sort=True, separator=u";"), "bar=23;blah=H%C3%A4nsel;foo=1"
+ )
def test_sorted_url_encode():
- strict_eq(urls.url_encode({u"a": 42, u"b": 23, 1: 1, 2: 2},
- sort=True, key=lambda i: text_type(i[0])), '1=1&2=2&a=42&b=23')
- strict_eq(urls.url_encode({u'A': 1, u'a': 2, u'B': 3, 'b': 4}, sort=True,
- key=lambda x: x[0].lower() + x[0]), 'A=1&a=2&B=3&b=4')
+ strict_eq(
+ urls.url_encode(
+ {u"a": 42, u"b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: text_type(i[0])
+ ),
+ "1=1&2=2&a=42&b=23",
+ )
+ strict_eq(
+ urls.url_encode(
+ {u"A": 1, u"a": 2, u"B": 3, "b": 4},
+ sort=True,
+ key=lambda x: x[0].lower() + x[0],
+ ),
+ "A=1&a=2&B=3&b=4",
+ )
def test_streamed_url_encoding():
out = NativeStringIO()
- urls.url_encode_stream({'foo': 'bar 45'}, out)
- strict_eq(out.getvalue(), 'foo=bar+45')
+ urls.url_encode_stream({"foo": "bar 45"}, out)
+ strict_eq(out.getvalue(), "foo=bar+45")
- d = {'foo': 1, 'bar': 23, 'blah': u'Hänsel'}
+ d = {"foo": 1, "bar": 23, "blah": u"Hänsel"}
out = NativeStringIO()
urls.url_encode_stream(d, out, sort=True)
- strict_eq(out.getvalue(), 'bar=23&blah=H%C3%A4nsel&foo=1')
+ strict_eq(out.getvalue(), "bar=23&blah=H%C3%A4nsel&foo=1")
out = NativeStringIO()
- urls.url_encode_stream(d, out, sort=True, separator=u';')
- strict_eq(out.getvalue(), 'bar=23;blah=H%C3%A4nsel;foo=1')
+ urls.url_encode_stream(d, out, sort=True, separator=u";")
+ strict_eq(out.getvalue(), "bar=23;blah=H%C3%A4nsel;foo=1")
gen = urls.url_encode_stream(d, sort=True)
- strict_eq(next(gen), 'bar=23')
- strict_eq(next(gen), 'blah=H%C3%A4nsel')
- strict_eq(next(gen), 'foo=1')
+ strict_eq(next(gen), "bar=23")
+ strict_eq(next(gen), "blah=H%C3%A4nsel")
+ strict_eq(next(gen), "foo=1")
pytest.raises(StopIteration, lambda: next(gen))
def test_url_fixing():
- x = urls.url_fix(u'http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)')
- assert x == 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)'
+ x = urls.url_fix(u"http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)")
+ assert x == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)"
x = urls.url_fix("http://just.a.test/$-_.+!*'(),")
assert x == "http://just.a.test/$-_.+!*'(),"
- x = urls.url_fix('http://höhöhö.at/höhöhö/hähähä')
- assert x == r'http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4'
+ x = urls.url_fix("http://höhöhö.at/höhöhö/hähähä")
+ assert x == r"http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4"
def test_url_fixing_filepaths():
- x = urls.url_fix(r'file://C:\Users\Administrator\My Documents\ÑÈáÇíí')
- assert x == (r'file:///C%3A/Users/Administrator/My%20Documents/'
- r'%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD')
+ x = urls.url_fix(r"file://C:\Users\Administrator\My Documents\ÑÈáÇíí")
+ assert x == (
+ r"file:///C%3A/Users/Administrator/My%20Documents/"
+ r"%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD"
+ )
- a = urls.url_fix(r'file:/C:/')
- b = urls.url_fix(r'file://C:/')
- c = urls.url_fix(r'file:///C:/')
- assert a == b == c == r'file:///C%3A/'
+ a = urls.url_fix(r"file:/C:/")
+ b = urls.url_fix(r"file://C:/")
+ c = urls.url_fix(r"file:///C:/")
+ assert a == b == c == r"file:///C%3A/"
- x = urls.url_fix(r'file://host/sub/path')
- assert x == r'file://host/sub/path'
+ x = urls.url_fix(r"file://host/sub/path")
+ assert x == r"file://host/sub/path"
- x = urls.url_fix(r'file:///')
- assert x == r'file:///'
+ x = urls.url_fix(r"file:///")
+ assert x == r"file:///"
def test_url_fixing_qs():
- x = urls.url_fix(b'http://example.com/?foo=%2f%2f')
- assert x == 'http://example.com/?foo=%2f%2f'
+ x = urls.url_fix(b"http://example.com/?foo=%2f%2f")
+ assert x == "http://example.com/?foo=%2f%2f"
- x = urls.url_fix('http://acronyms.thefreedictionary.com/'
- 'Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation')
- assert x == ('http://acronyms.thefreedictionary.com/'
- 'Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation')
+ x = urls.url_fix(
+ "http://acronyms.thefreedictionary.com/"
+ "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation"
+ )
+ assert x == (
+ "http://acronyms.thefreedictionary.com/"
+ "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation"
+ )
def test_iri_support():
- strict_eq(urls.uri_to_iri('http://xn--n3h.net/'),
- u'http://\u2603.net/')
+ strict_eq(urls.uri_to_iri("http://xn--n3h.net/"), u"http://\u2603.net/")
strict_eq(
- urls.uri_to_iri(b'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th'),
- u'http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th')
- strict_eq(urls.iri_to_uri(u'http://☃.net/'), 'http://xn--n3h.net/')
+ urls.uri_to_iri(b"http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th"),
+ u"http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th",
+ )
+ strict_eq(urls.iri_to_uri(u"http://☃.net/"), "http://xn--n3h.net/")
strict_eq(
- urls.iri_to_uri(u'http://üser:pässword@☃.net/påth'),
- 'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th')
+ urls.iri_to_uri(u"http://üser:pässword@☃.net/påth"),
+ "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th",
+ )
- strict_eq(urls.uri_to_iri('http://test.com/%3Fmeh?foo=%26%2F'),
- u'http://test.com/%3Fmeh?foo=%26%2F')
+ strict_eq(
+ urls.uri_to_iri("http://test.com/%3Fmeh?foo=%26%2F"),
+ u"http://test.com/%3Fmeh?foo=%26%2F",
+ )
# this should work as well, might break on 2.4 because of a broken
# idna codec
- strict_eq(urls.uri_to_iri(b'/foo'), u'/foo')
- strict_eq(urls.iri_to_uri(u'/foo'), '/foo')
+ strict_eq(urls.uri_to_iri(b"/foo"), u"/foo")
+ strict_eq(urls.iri_to_uri(u"/foo"), "/foo")
- strict_eq(urls.iri_to_uri(u'http://föö.com:8080/bam/baz'),
- 'http://xn--f-1gaa.com:8080/bam/baz')
+ strict_eq(
+ urls.iri_to_uri(u"http://föö.com:8080/bam/baz"),
+ "http://xn--f-1gaa.com:8080/bam/baz",
+ )
def test_iri_safe_conversion():
- strict_eq(urls.iri_to_uri(u'magnet:?foo=bar'),
- 'magnet:?foo=bar')
- strict_eq(urls.iri_to_uri(u'itms-service://?foo=bar'),
- 'itms-service:?foo=bar')
- strict_eq(urls.iri_to_uri(u'itms-service://?foo=bar',
- safe_conversion=True),
- 'itms-service://?foo=bar')
+ strict_eq(urls.iri_to_uri(u"magnet:?foo=bar"), "magnet:?foo=bar")
+ strict_eq(urls.iri_to_uri(u"itms-service://?foo=bar"), "itms-service:?foo=bar")
+ strict_eq(
+ urls.iri_to_uri(u"itms-service://?foo=bar", safe_conversion=True),
+ "itms-service://?foo=bar",
+ )
def test_iri_safe_quoting():
- uri = 'http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25'
- iri = u'http://föö.com/%2F%25?q=ö&x=%3D%25#%25'
+ uri = "http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25"
+ iri = u"http://föö.com/%2F%25?q=ö&x=%3D%25#%25"
strict_eq(urls.uri_to_iri(uri), iri)
strict_eq(urls.iri_to_uri(urls.uri_to_iri(uri)), uri)
def test_ordered_multidict_encoding():
d = OrderedMultiDict()
- d.add('foo', 1)
- d.add('foo', 2)
- d.add('foo', 3)
- d.add('bar', 0)
- d.add('foo', 4)
- assert urls.url_encode(d) == 'foo=1&foo=2&foo=3&bar=0&foo=4'
+ d.add("foo", 1)
+ d.add("foo", 2)
+ d.add("foo", 3)
+ d.add("bar", 0)
+ d.add("foo", 4)
+ assert urls.url_encode(d) == "foo=1&foo=2&foo=3&bar=0&foo=4"
def test_multidict_encoding():
d = OrderedMultiDict()
- d.add('2013-10-10T23:26:05.657975+0000', '2013-10-10T23:26:05.657975+0000')
- assert urls.url_encode(
- d) == '2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000'
+ d.add("2013-10-10T23:26:05.657975+0000", "2013-10-10T23:26:05.657975+0000")
+ assert (
+ urls.url_encode(d)
+ == "2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000"
+ )
def test_href():
- x = urls.Href('http://www.example.com/')
- strict_eq(x(u'foo'), 'http://www.example.com/foo')
- strict_eq(x.foo(u'bar'), 'http://www.example.com/foo/bar')
- strict_eq(x.foo(u'bar', x=42), 'http://www.example.com/foo/bar?x=42')
- strict_eq(x.foo(u'bar', class_=42), 'http://www.example.com/foo/bar?class=42')
- strict_eq(x.foo(u'bar', {u'class': 42}), 'http://www.example.com/foo/bar?class=42')
+ x = urls.Href("http://www.example.com/")
+ strict_eq(x(u"foo"), "http://www.example.com/foo")
+ strict_eq(x.foo(u"bar"), "http://www.example.com/foo/bar")
+ strict_eq(x.foo(u"bar", x=42), "http://www.example.com/foo/bar?x=42")
+ strict_eq(x.foo(u"bar", class_=42), "http://www.example.com/foo/bar?class=42")
+ strict_eq(x.foo(u"bar", {u"class": 42}), "http://www.example.com/foo/bar?class=42")
pytest.raises(AttributeError, lambda: x.__blah__)
- x = urls.Href('blah')
- strict_eq(x.foo(u'bar'), 'blah/foo/bar')
+ x = urls.Href("blah")
+ strict_eq(x.foo(u"bar"), "blah/foo/bar")
pytest.raises(TypeError, x.foo, {u"foo": 23}, x=42)
- x = urls.Href('')
- strict_eq(x('foo'), 'foo')
+ x = urls.Href("")
+ strict_eq(x("foo"), "foo")
def test_href_url_join():
- x = urls.Href(u'test')
- assert x(u'foo:bar') == u'test/foo:bar'
- assert x(u'http://example.com/') == u'test/http://example.com/'
- assert x.a() == u'test/a'
+ x = urls.Href(u"test")
+ assert x(u"foo:bar") == u"test/foo:bar"
+ assert x(u"http://example.com/") == u"test/http://example.com/"
+ assert x.a() == u"test/a"
def test_href_past_root():
- base_href = urls.Href('http://www.blagga.com/1/2/3')
- strict_eq(base_href('../foo'), 'http://www.blagga.com/1/2/foo')
- strict_eq(base_href('../../foo'), 'http://www.blagga.com/1/foo')
- strict_eq(base_href('../../../foo'), 'http://www.blagga.com/foo')
- strict_eq(base_href('../../../../foo'), 'http://www.blagga.com/foo')
- strict_eq(base_href('../../../../../foo'), 'http://www.blagga.com/foo')
- strict_eq(base_href('../../../../../../foo'), 'http://www.blagga.com/foo')
+ base_href = urls.Href("http://www.blagga.com/1/2/3")
+ strict_eq(base_href("../foo"), "http://www.blagga.com/1/2/foo")
+ strict_eq(base_href("../../foo"), "http://www.blagga.com/1/foo")
+ strict_eq(base_href("../../../foo"), "http://www.blagga.com/foo")
+ strict_eq(base_href("../../../../foo"), "http://www.blagga.com/foo")
+ strict_eq(base_href("../../../../../foo"), "http://www.blagga.com/foo")
+ strict_eq(base_href("../../../../../../foo"), "http://www.blagga.com/foo")
def test_url_unquote_plus_unicode():
# was broken in 0.6
- strict_eq(urls.url_unquote_plus(u'\x6d'), u'\x6d')
- assert type(urls.url_unquote_plus(u'\x6d')) is text_type
+ strict_eq(urls.url_unquote_plus(u"\x6d"), u"\x6d")
+ assert type(urls.url_unquote_plus(u"\x6d")) is text_type
def test_quoting_of_local_urls():
- rv = urls.iri_to_uri(u'/foo\x8f')
- strict_eq(rv, '/foo%C2%8F')
+ rv = urls.iri_to_uri(u"/foo\x8f")
+ strict_eq(rv, "/foo%C2%8F")
assert type(rv) is str
def test_url_attributes():
- rv = urls.url_parse('http://foo%3a:bar%3a@[::1]:80/123?x=y#frag')
- strict_eq(rv.scheme, 'http')
- strict_eq(rv.auth, 'foo%3a:bar%3a')
- strict_eq(rv.username, u'foo:')
- strict_eq(rv.password, u'bar:')
- strict_eq(rv.raw_username, 'foo%3a')
- strict_eq(rv.raw_password, 'bar%3a')
- strict_eq(rv.host, '::1')
+ rv = urls.url_parse("http://foo%3a:bar%3a@[::1]:80/123?x=y#frag")
+ strict_eq(rv.scheme, "http")
+ strict_eq(rv.auth, "foo%3a:bar%3a")
+ strict_eq(rv.username, u"foo:")
+ strict_eq(rv.password, u"bar:")
+ strict_eq(rv.raw_username, "foo%3a")
+ strict_eq(rv.raw_password, "bar%3a")
+ strict_eq(rv.host, "::1")
assert rv.port == 80
- strict_eq(rv.path, '/123')
- strict_eq(rv.query, 'x=y')
- strict_eq(rv.fragment, 'frag')
+ strict_eq(rv.path, "/123")
+ strict_eq(rv.query, "x=y")
+ strict_eq(rv.fragment, "frag")
- rv = urls.url_parse(u'http://\N{SNOWMAN}.com/')
- strict_eq(rv.host, u'\N{SNOWMAN}.com')
- strict_eq(rv.ascii_host, 'xn--n3h.com')
+ rv = urls.url_parse(u"http://\N{SNOWMAN}.com/")
+ strict_eq(rv.host, u"\N{SNOWMAN}.com")
+ strict_eq(rv.ascii_host, "xn--n3h.com")
def test_url_attributes_bytes():
- rv = urls.url_parse(b'http://foo%3a:bar%3a@[::1]:80/123?x=y#frag')
- strict_eq(rv.scheme, b'http')
- strict_eq(rv.auth, b'foo%3a:bar%3a')
- strict_eq(rv.username, u'foo:')
- strict_eq(rv.password, u'bar:')
- strict_eq(rv.raw_username, b'foo%3a')
- strict_eq(rv.raw_password, b'bar%3a')
- strict_eq(rv.host, b'::1')
+ rv = urls.url_parse(b"http://foo%3a:bar%3a@[::1]:80/123?x=y#frag")
+ strict_eq(rv.scheme, b"http")
+ strict_eq(rv.auth, b"foo%3a:bar%3a")
+ strict_eq(rv.username, u"foo:")
+ strict_eq(rv.password, u"bar:")
+ strict_eq(rv.raw_username, b"foo%3a")
+ strict_eq(rv.raw_password, b"bar%3a")
+ strict_eq(rv.host, b"::1")
assert rv.port == 80
- strict_eq(rv.path, b'/123')
- strict_eq(rv.query, b'x=y')
- strict_eq(rv.fragment, b'frag')
+ strict_eq(rv.path, b"/123")
+ strict_eq(rv.query, b"x=y")
+ strict_eq(rv.fragment, b"frag")
def test_url_joining():
- strict_eq(urls.url_join('/foo', '/bar'), '/bar')
- strict_eq(urls.url_join('http://example.com/foo', '/bar'),
- 'http://example.com/bar')
- strict_eq(urls.url_join('file:///tmp/', 'test.html'),
- 'file:///tmp/test.html')
- strict_eq(urls.url_join('file:///tmp/x', 'test.html'),
- 'file:///tmp/test.html')
- strict_eq(urls.url_join('file:///tmp/x', '../../../x.html'),
- 'file:///x.html')
+ strict_eq(urls.url_join("/foo", "/bar"), "/bar")
+ strict_eq(urls.url_join("http://example.com/foo", "/bar"), "http://example.com/bar")
+ strict_eq(urls.url_join("file:///tmp/", "test.html"), "file:///tmp/test.html")
+ strict_eq(urls.url_join("file:///tmp/x", "test.html"), "file:///tmp/test.html")
+ strict_eq(urls.url_join("file:///tmp/x", "../../../x.html"), "file:///x.html")
def test_partial_unencoded_decode():
- ref = u'foo=정상처리'.encode('euc-kr')
- x = urls.url_decode(ref, charset='euc-kr')
- strict_eq(x['foo'], u'정상처리')
+ ref = u"foo=정상처리".encode("euc-kr")
+ x = urls.url_decode(ref, charset="euc-kr")
+ strict_eq(x["foo"], u"정상처리")
def test_iri_to_uri_idempotence_ascii_only():
- uri = u'http://www.idempoten.ce'
+ uri = u"http://www.idempoten.ce"
uri = urls.iri_to_uri(uri)
assert urls.iri_to_uri(uri) == uri
def test_iri_to_uri_idempotence_non_ascii():
- uri = u'http://\N{SNOWMAN}/\N{SNOWMAN}'
+ uri = u"http://\N{SNOWMAN}/\N{SNOWMAN}"
uri = urls.iri_to_uri(uri)
assert urls.iri_to_uri(uri) == uri
def test_uri_to_iri_idempotence_ascii_only():
- uri = 'http://www.idempoten.ce'
+ uri = "http://www.idempoten.ce"
uri = urls.uri_to_iri(uri)
assert urls.uri_to_iri(uri) == uri
def test_uri_to_iri_idempotence_non_ascii():
- uri = 'http://xn--n3h/%E2%98%83'
+ uri = "http://xn--n3h/%E2%98%83"
uri = urls.uri_to_iri(uri)
assert urls.uri_to_iri(uri) == uri
def test_iri_to_uri_to_iri():
- iri = u'http://föö.com/'
+ iri = u"http://föö.com/"
uri = urls.iri_to_uri(iri)
assert urls.uri_to_iri(uri) == iri
def test_uri_to_iri_to_uri():
- uri = 'http://xn--f-rgao.com/%C3%9E'
+ uri = "http://xn--f-rgao.com/%C3%9E"
iri = urls.uri_to_iri(uri)
assert urls.iri_to_uri(iri) == uri
def test_uri_iri_normalization():
- uri = 'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93'
- iri = u'http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713'
+ uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93"
+ iri = u"http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713"
tests = [
- u'http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713',
- u'http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}',
- b'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93',
- u'http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93',
- u'http://föñ.com/\u2610/fred?utf8=%E2%9C%93',
- b'http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93',
+ u"http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713",
+ u"http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}",
+ b"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93",
+ u"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93",
+ u"http://föñ.com/\u2610/fred?utf8=%E2%9C%93",
+ b"http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93",
]
for test in tests:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0435a251..f288edea 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -8,43 +8,46 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import pytest
-
-from datetime import datetime
import inspect
+from datetime import datetime
+
+import pytest
from werkzeug import utils
+from werkzeug._compat import text_type
from werkzeug.datastructures import Headers
-from werkzeug.http import parse_date, http_date
-from werkzeug.wrappers import BaseResponse
+from werkzeug.http import http_date
+from werkzeug.http import parse_date
from werkzeug.test import Client
-from werkzeug._compat import text_type
+from werkzeug.wrappers import BaseResponse
def test_redirect():
- resp = utils.redirect(u'/füübär')
- assert b'/f%C3%BC%C3%BCb%C3%A4r' in resp.get_data()
- assert resp.headers['Location'] == '/f%C3%BC%C3%BCb%C3%A4r'
+ resp = utils.redirect(u"/füübär")
+ assert b"/f%C3%BC%C3%BCb%C3%A4r" in resp.get_data()
+ assert resp.headers["Location"] == "/f%C3%BC%C3%BCb%C3%A4r"
assert resp.status_code == 302
- resp = utils.redirect(u'http://☃.net/', 307)
- assert b'http://xn--n3h.net/' in resp.get_data()
- assert resp.headers['Location'] == 'http://xn--n3h.net/'
+ resp = utils.redirect(u"http://☃.net/", 307)
+ assert b"http://xn--n3h.net/" in resp.get_data()
+ assert resp.headers["Location"] == "http://xn--n3h.net/"
assert resp.status_code == 307
- resp = utils.redirect('http://example.com/', 305)
- assert resp.headers['Location'] == 'http://example.com/'
+ resp = utils.redirect("http://example.com/", 305)
+ assert resp.headers["Location"] == "http://example.com/"
assert resp.status_code == 305
def test_redirect_xss():
location = 'http://example.com/?xss="><script>alert(1)</script>'
resp = utils.redirect(location)
- assert b'<script>alert(1)</script>' not in resp.get_data()
+ assert b"<script>alert(1)</script>" not in resp.get_data()
location = 'http://example.com/?xss="onmouseover="alert(1)'
resp = utils.redirect(location)
- assert b'href="http://example.com/?xss="onmouseover="alert(1)"' not in resp.get_data()
+ assert (
+ b'href="http://example.com/?xss="onmouseover="alert(1)"' not in resp.get_data()
+ )
def test_redirect_with_custom_response_class():
@@ -55,17 +58,17 @@ def test_redirect_with_custom_response_class():
resp = utils.redirect(location, Response=MyResponse)
assert isinstance(resp, MyResponse)
- assert resp.headers['Location'] == location
+ assert resp.headers["Location"] == location
def test_cached_property():
foo = []
class A(object):
-
def prop(self):
foo.append(42)
return 42
+
prop = utils.cached_property(prop)
a = A()
@@ -77,11 +80,11 @@ def test_cached_property():
foo = []
class A(object):
-
def _prop(self):
foo.append(42)
return 42
- prop = utils.cached_property(_prop, name='prop')
+
+ prop = utils.cached_property(_prop, name="prop")
del _prop
a = A()
@@ -93,198 +96,222 @@ def test_cached_property():
def test_can_set_cached_property():
class A(object):
-
@utils.cached_property
def _prop(self):
- return 'cached_property return value'
+ return "cached_property return value"
a = A()
- a._prop = 'value'
- assert a._prop == 'value'
+ a._prop = "value"
+ assert a._prop == "value"
def test_inspect_treats_cached_property_as_property():
class A(object):
-
@utils.cached_property
def _prop(self):
- return 'cached_property return value'
+ return "cached_property return value"
attrs = inspect.classify_class_attrs(A)
for attr in attrs:
- if attr.name == '_prop':
+ if attr.name == "_prop":
break
- assert attr.kind == 'property'
+ assert attr.kind == "property"
def test_environ_property():
class A(object):
- environ = {'string': 'abc', 'number': '42'}
-
- string = utils.environ_property('string')
- missing = utils.environ_property('missing', 'spam')
- read_only = utils.environ_property('number')
- number = utils.environ_property('number', load_func=int)
- broken_number = utils.environ_property('broken_number', load_func=int)
- date = utils.environ_property('date', None, parse_date, http_date,
- read_only=False)
- foo = utils.environ_property('foo')
+ environ = {"string": "abc", "number": "42"}
+
+ string = utils.environ_property("string")
+ missing = utils.environ_property("missing", "spam")
+ read_only = utils.environ_property("number")
+ number = utils.environ_property("number", load_func=int)
+ broken_number = utils.environ_property("broken_number", load_func=int)
+ date = utils.environ_property(
+ "date", None, parse_date, http_date, read_only=False
+ )
+ foo = utils.environ_property("foo")
a = A()
- assert a.string == 'abc'
- assert a.missing == 'spam'
+ assert a.string == "abc"
+ assert a.missing == "spam"
def test_assign():
- a.read_only = 'something'
+ a.read_only = "something"
+
pytest.raises(AttributeError, test_assign)
assert a.number == 42
assert a.broken_number is None
assert a.date is None
a.date = datetime(2008, 1, 22, 10, 0, 0, 0)
- assert a.environ['date'] == 'Tue, 22 Jan 2008 10:00:00 GMT'
+ assert a.environ["date"] == "Tue, 22 Jan 2008 10:00:00 GMT"
def test_escape():
class Foo(str):
-
def __html__(self):
return text_type(self)
- assert utils.escape(None) == ''
- assert utils.escape(42) == '42'
- assert utils.escape('<>') == '&lt;&gt;'
- assert utils.escape('"foo"') == '&quot;foo&quot;'
- assert utils.escape(Foo('<foo>')) == '<foo>'
+
+ assert utils.escape(None) == ""
+ assert utils.escape(42) == "42"
+ assert utils.escape("<>") == "&lt;&gt;"
+ assert utils.escape('"foo"') == "&quot;foo&quot;"
+ assert utils.escape(Foo("<foo>")) == "<foo>"
def test_unescape():
- assert utils.unescape('&lt;&auml;&gt;') == u'<ä>'
+ assert utils.unescape("&lt;&auml;&gt;") == u"<ä>"
def test_import_string():
from datetime import date
from werkzeug.debug import DebuggedApplication
- assert utils.import_string('datetime.date') is date
- assert utils.import_string(u'datetime.date') is date
- assert utils.import_string('datetime:date') is date
- assert utils.import_string('XXXXXXXXXXXX', True) is None
- assert utils.import_string('datetime.XXXXXXXXXXXX', True) is None
- assert utils.import_string(u'werkzeug.debug.DebuggedApplication') is DebuggedApplication
- pytest.raises(ImportError, utils.import_string, 'XXXXXXXXXXXXXXXX')
- pytest.raises(ImportError, utils.import_string, 'datetime.XXXXXXXXXX')
+
+ assert utils.import_string("datetime.date") is date
+ assert utils.import_string(u"datetime.date") is date
+ assert utils.import_string("datetime:date") is date
+ assert utils.import_string("XXXXXXXXXXXX", True) is None
+ assert utils.import_string("datetime.XXXXXXXXXXXX", True) is None
+ assert (
+ utils.import_string(u"werkzeug.debug.DebuggedApplication")
+ is DebuggedApplication
+ )
+ pytest.raises(ImportError, utils.import_string, "XXXXXXXXXXXXXXXX")
+ pytest.raises(ImportError, utils.import_string, "datetime.XXXXXXXXXX")
def test_import_string_provides_traceback(tmpdir, monkeypatch):
monkeypatch.syspath_prepend(str(tmpdir))
# Couple of packages
- dir_a = tmpdir.mkdir('a')
- dir_b = tmpdir.mkdir('b')
+ dir_a = tmpdir.mkdir("a")
+ dir_b = tmpdir.mkdir("b")
# Totally packages, I promise
- dir_a.join('__init__.py').write('')
- dir_b.join('__init__.py').write('')
+ dir_a.join("__init__.py").write("")
+ dir_b.join("__init__.py").write("")
# 'aa.a' that depends on 'bb.b', which in turn has a broken import
- dir_a.join('aa.py').write('from b import bb')
- dir_b.join('bb.py').write('from os import a_typo')
+ dir_a.join("aa.py").write("from b import bb")
+ dir_b.join("bb.py").write("from os import a_typo")
# Do we get all the useful information in the traceback?
with pytest.raises(ImportError) as baz_exc:
- utils.import_string('a.aa')
- traceback = ''.join((str(line) for line in baz_exc.traceback))
- assert 'bb.py\':1' in traceback # a bit different than typical python tb
- assert 'from os import a_typo' in traceback
+ utils.import_string("a.aa")
+ traceback = "".join((str(line) for line in baz_exc.traceback))
+ assert "bb.py':1" in traceback # a bit different than typical python tb
+ assert "from os import a_typo" in traceback
def test_import_string_attribute_error(tmpdir, monkeypatch):
monkeypatch.syspath_prepend(str(tmpdir))
- tmpdir.join('foo_test.py').write('from bar_test import value')
- tmpdir.join('bar_test.py').write('raise AttributeError("screw you!")')
+ tmpdir.join("foo_test.py").write("from bar_test import value")
+ tmpdir.join("bar_test.py").write('raise AttributeError("screw you!")')
with pytest.raises(AttributeError) as foo_exc:
- utils.import_string('foo_test')
- assert 'screw you!' in str(foo_exc)
+ utils.import_string("foo_test")
+ assert "screw you!" in str(foo_exc)
with pytest.raises(AttributeError) as bar_exc:
- utils.import_string('bar_test')
- assert 'screw you!' in str(bar_exc)
+ utils.import_string("bar_test")
+ assert "screw you!" in str(bar_exc)
def test_find_modules():
- assert list(utils.find_modules('werkzeug.debug')) == [
- 'werkzeug.debug.console', 'werkzeug.debug.repr',
- 'werkzeug.debug.tbtools'
+ assert list(utils.find_modules("werkzeug.debug")) == [
+ "werkzeug.debug.console",
+ "werkzeug.debug.repr",
+ "werkzeug.debug.tbtools",
]
def test_html_builder():
html = utils.html
xhtml = utils.xhtml
- assert html.p('Hello World') == '<p>Hello World</p>'
- assert html.a('Test', href='#') == '<a href="#">Test</a>'
- assert html.br() == '<br>'
- assert xhtml.br() == '<br />'
- assert html.img(src='foo') == '<img src="foo">'
- assert xhtml.img(src='foo') == '<img src="foo" />'
- assert html.html(html.head(
- html.title('foo'),
- html.script(type='text/javascript')
- )) == (
+ assert html.p("Hello World") == "<p>Hello World</p>"
+ assert html.a("Test", href="#") == '<a href="#">Test</a>'
+ assert html.br() == "<br>"
+ assert xhtml.br() == "<br />"
+ assert html.img(src="foo") == '<img src="foo">'
+ assert xhtml.img(src="foo") == '<img src="foo" />'
+ assert html.html(
+ html.head(html.title("foo"), html.script(type="text/javascript"))
+ ) == (
'<html><head><title>foo</title><script type="text/javascript">'
- '</script></head></html>'
+ "</script></head></html>"
)
- assert html('<foo>') == '&lt;foo&gt;'
- assert html.input(disabled=True) == '<input disabled>'
+ assert html("<foo>") == "&lt;foo&gt;"
+ assert html.input(disabled=True) == "<input disabled>"
assert xhtml.input(disabled=True) == '<input disabled="disabled" />'
- assert html.input(disabled='') == '<input>'
- assert xhtml.input(disabled='') == '<input />'
- assert html.input(disabled=None) == '<input>'
- assert xhtml.input(disabled=None) == '<input />'
- assert html.script('alert("Hello World");') == \
- '<script>alert("Hello World");</script>'
- assert xhtml.script('alert("Hello World");') == \
- '<script>/*<![CDATA[*/alert("Hello World");/*]]>*/</script>'
+ assert html.input(disabled="") == "<input>"
+ assert xhtml.input(disabled="") == "<input />"
+ assert html.input(disabled=None) == "<input>"
+ assert xhtml.input(disabled=None) == "<input />"
+ assert (
+ html.script('alert("Hello World");') == '<script>alert("Hello World");</script>'
+ )
+ assert (
+ xhtml.script('alert("Hello World");')
+ == '<script>/*<![CDATA[*/alert("Hello World");/*]]>*/</script>'
+ )
def test_validate_arguments():
- take_none = lambda: None
- take_two = lambda a, b: None
- take_two_one_default = lambda a, b=0: None
+ def take_none():
+ pass
+
+ def take_two(a, b):
+ pass
- assert utils.validate_arguments(take_two, (1, 2,), {}) == ((1, 2), {})
- assert utils.validate_arguments(take_two, (1,), {'b': 2}) == ((1, 2), {})
+ def take_two_one_default(a, b=0):
+ pass
+
+ assert utils.validate_arguments(take_two, (1, 2), {}) == ((1, 2), {})
+ assert utils.validate_arguments(take_two, (1,), {"b": 2}) == ((1, 2), {})
assert utils.validate_arguments(take_two_one_default, (1,), {}) == ((1, 0), {})
assert utils.validate_arguments(take_two_one_default, (1, 2), {}) == ((1, 2), {})
- pytest.raises(utils.ArgumentValidationError,
- utils.validate_arguments, take_two, (), {})
+ pytest.raises(
+ utils.ArgumentValidationError, utils.validate_arguments, take_two, (), {}
+ )
- assert utils.validate_arguments(take_none, (1, 2,), {'c': 3}) == ((), {})
- pytest.raises(utils.ArgumentValidationError,
- utils.validate_arguments, take_none, (1,), {}, drop_extra=False)
- pytest.raises(utils.ArgumentValidationError,
- utils.validate_arguments, take_none, (), {'a': 1}, drop_extra=False)
+ assert utils.validate_arguments(take_none, (1, 2), {"c": 3}) == ((), {})
+ pytest.raises(
+ utils.ArgumentValidationError,
+ utils.validate_arguments,
+ take_none,
+ (1,),
+ {},
+ drop_extra=False,
+ )
+ pytest.raises(
+ utils.ArgumentValidationError,
+ utils.validate_arguments,
+ take_none,
+ (),
+ {"a": 1},
+ drop_extra=False,
+ )
def test_header_set_duplication_bug():
- headers = Headers([
- ('Content-Type', 'text/html'),
- ('Foo', 'bar'),
- ('Blub', 'blah')
- ])
- headers['blub'] = 'hehe'
- headers['blafasel'] = 'humm'
- assert headers == Headers([
- ('Content-Type', 'text/html'),
- ('Foo', 'bar'),
- ('blub', 'hehe'),
- ('blafasel', 'humm')
- ])
+ headers = Headers([("Content-Type", "text/html"), ("Foo", "bar"), ("Blub", "blah")])
+ headers["blub"] = "hehe"
+ headers["blafasel"] = "humm"
+ assert headers == Headers(
+ [
+ ("Content-Type", "text/html"),
+ ("Foo", "bar"),
+ ("blub", "hehe"),
+ ("blafasel", "humm"),
+ ]
+ )
def test_append_slash_redirect():
def app(env, sr):
return utils.append_slash_redirect(env)(env, sr)
+
client = Client(app, BaseResponse)
- response = client.get('foo', base_url='http://example.org/app')
+ response = client.get("foo", base_url="http://example.org/app")
assert response.status_code == 301
- assert response.headers['Location'] == 'http://example.org/app/foo/'
+ assert response.headers["Location"] == "http://example.org/app/foo/"
def test_cached_property_doc():
@@ -292,15 +319,18 @@ def test_cached_property_doc():
def foo():
"""testing"""
return 42
- assert foo.__doc__ == 'testing'
- assert foo.__name__ == 'foo'
+
+ assert foo.__doc__ == "testing"
+ assert foo.__name__ == "foo"
assert foo.__module__ == __name__
def test_secure_filename():
- assert utils.secure_filename('My cool movie.mov') == 'My_cool_movie.mov'
- assert utils.secure_filename('../../../etc/passwd') == 'etc_passwd'
- assert utils.secure_filename(u'i contain cool \xfcml\xe4uts.txt') == \
- 'i_contain_cool_umlauts.txt'
- assert utils.secure_filename('__filename__') == 'filename'
- assert utils.secure_filename('foo$&^*)bar') == 'foobar'
+ assert utils.secure_filename("My cool movie.mov") == "My_cool_movie.mov"
+ assert utils.secure_filename("../../../etc/passwd") == "etc_passwd"
+ assert (
+ utils.secure_filename(u"i contain cool \xfcml\xe4uts.txt")
+ == "i_contain_cool_umlauts.txt"
+ )
+ assert utils.secure_filename("__filename__") == "filename"
+ assert utils.secure_filename("foo$&^*)bar") == "foobar"
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index 203bea6c..511d66de 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -11,33 +11,41 @@
import contextlib
import json
import os
-
-import pytest
-
import pickle
+from datetime import datetime
+from datetime import timedelta
from io import BytesIO
-from datetime import datetime, timedelta
-from werkzeug._compat import iteritems
-from tests import strict_eq
+import pytest
+from . import strict_eq
from werkzeug import wrappers
-from werkzeug.exceptions import SecurityError, RequestedRangeNotSatisfiable, \
- BadRequest
+from werkzeug._compat import implements_iterator
+from werkzeug._compat import iteritems
+from werkzeug._compat import text_type
+from werkzeug.datastructures import Accept
+from werkzeug.datastructures import CharsetAccept
+from werkzeug.datastructures import CombinedMultiDict
+from werkzeug.datastructures import Headers
+from werkzeug.datastructures import ImmutableList
+from werkzeug.datastructures import ImmutableOrderedMultiDict
+from werkzeug.datastructures import ImmutableTypeConversionDict
+from werkzeug.datastructures import LanguageAccept
+from werkzeug.datastructures import MIMEAccept
+from werkzeug.datastructures import MultiDict
+from werkzeug.exceptions import BadRequest
+from werkzeug.exceptions import RequestedRangeNotSatisfiable
+from werkzeug.exceptions import SecurityError
from werkzeug.http import generate_etag
+from werkzeug.test import Client
+from werkzeug.test import create_environ
+from werkzeug.test import run_wsgi_app
from werkzeug.wrappers.json import JSONMixin
-from werkzeug.wsgi import LimitedStream, wrap_file
-from werkzeug.datastructures import (
- MultiDict, ImmutableOrderedMultiDict,
- ImmutableList, ImmutableTypeConversionDict, CharsetAccept,
- MIMEAccept, LanguageAccept, Accept, CombinedMultiDict, Headers,
-)
-from werkzeug.test import Client, create_environ, run_wsgi_app
-from werkzeug._compat import implements_iterator, text_type
+from werkzeug.wsgi import LimitedStream
+from werkzeug.wsgi import wrap_file
class RequestTestResponse(wrappers.BaseResponse):
-
"""Subclass of the normal response class we use to test response
and base classes. Has some methods to test if things in the
response match.
@@ -53,16 +61,20 @@ class RequestTestResponse(wrappers.BaseResponse):
def request_demo_app(environ, start_response):
request = wrappers.BaseRequest(environ)
- assert 'werkzeug.request' in environ
- start_response('200 OK', [('Content-Type', 'text/plain')])
- return [pickle.dumps({
- 'args': request.args,
- 'args_as_list': list(request.args.lists()),
- 'form': request.form,
- 'form_as_list': list(request.form.lists()),
- 'environ': prepare_environ_pickle(request.environ),
- 'data': request.get_data()
- })]
+ assert "werkzeug.request" in environ
+ start_response("200 OK", [("Content-Type", "text/plain")])
+ return [
+ pickle.dumps(
+ {
+ "args": request.args,
+ "args_as_list": list(request.args.lists()),
+ "form": request.form,
+ "form_as_list": list(request.form.lists()),
+ "environ": prepare_environ_pickle(request.environ),
+ "data": request.get_data(),
+ }
+ )
+ ]
def prepare_environ_pickle(environ):
@@ -77,127 +89,132 @@ def prepare_environ_pickle(environ):
def assert_environ(environ, method):
- strict_eq(environ['REQUEST_METHOD'], method)
- strict_eq(environ['PATH_INFO'], '/')
- strict_eq(environ['SCRIPT_NAME'], '')
- strict_eq(environ['SERVER_NAME'], 'localhost')
- strict_eq(environ['wsgi.version'], (1, 0))
- strict_eq(environ['wsgi.url_scheme'], 'http')
+ strict_eq(environ["REQUEST_METHOD"], method)
+ strict_eq(environ["PATH_INFO"], "/")
+ strict_eq(environ["SCRIPT_NAME"], "")
+ strict_eq(environ["SERVER_NAME"], "localhost")
+ strict_eq(environ["wsgi.version"], (1, 0))
+ strict_eq(environ["wsgi.url_scheme"], "http")
def test_base_request():
client = Client(request_demo_app, RequestTestResponse)
# get requests
- response = client.get('/?foo=bar&foo=hehe')
- strict_eq(response['args'], MultiDict([('foo', u'bar'), ('foo', u'hehe')]))
- strict_eq(response['args_as_list'], [('foo', [u'bar', u'hehe'])])
- strict_eq(response['form'], MultiDict())
- strict_eq(response['form_as_list'], [])
- strict_eq(response['data'], b'')
- assert_environ(response['environ'], 'GET')
+ response = client.get("/?foo=bar&foo=hehe")
+ strict_eq(response["args"], MultiDict([("foo", u"bar"), ("foo", u"hehe")]))
+ strict_eq(response["args_as_list"], [("foo", [u"bar", u"hehe"])])
+ strict_eq(response["form"], MultiDict())
+ strict_eq(response["form_as_list"], [])
+ strict_eq(response["data"], b"")
+ assert_environ(response["environ"], "GET")
# post requests with form data
- response = client.post('/?blub=blah', data='foo=blub+hehe&blah=42',
- content_type='application/x-www-form-urlencoded')
- strict_eq(response['args'], MultiDict([('blub', u'blah')]))
- strict_eq(response['args_as_list'], [('blub', [u'blah'])])
- strict_eq(response['form'], MultiDict([('foo', u'blub hehe'), ('blah', u'42')]))
- strict_eq(response['data'], b'')
+ response = client.post(
+ "/?blub=blah",
+ data="foo=blub+hehe&blah=42",
+ content_type="application/x-www-form-urlencoded",
+ )
+ strict_eq(response["args"], MultiDict([("blub", u"blah")]))
+ strict_eq(response["args_as_list"], [("blub", [u"blah"])])
+ strict_eq(response["form"], MultiDict([("foo", u"blub hehe"), ("blah", u"42")]))
+ strict_eq(response["data"], b"")
# currently we do not guarantee that the values are ordered correctly
# for post data.
# strict_eq(response['form_as_list'], [('foo', ['blub hehe']), ('blah', ['42'])])
- assert_environ(response['environ'], 'POST')
+ assert_environ(response["environ"], "POST")
# patch requests with form data
- response = client.patch('/?blub=blah', data='foo=blub+hehe&blah=42',
- content_type='application/x-www-form-urlencoded')
- strict_eq(response['args'], MultiDict([('blub', u'blah')]))
- strict_eq(response['args_as_list'], [('blub', [u'blah'])])
- strict_eq(response['form'],
- MultiDict([('foo', u'blub hehe'), ('blah', u'42')]))
- strict_eq(response['data'], b'')
- assert_environ(response['environ'], 'PATCH')
+ response = client.patch(
+ "/?blub=blah",
+ data="foo=blub+hehe&blah=42",
+ content_type="application/x-www-form-urlencoded",
+ )
+ strict_eq(response["args"], MultiDict([("blub", u"blah")]))
+ strict_eq(response["args_as_list"], [("blub", [u"blah"])])
+ strict_eq(response["form"], MultiDict([("foo", u"blub hehe"), ("blah", u"42")]))
+ strict_eq(response["data"], b"")
+ assert_environ(response["environ"], "PATCH")
# post requests with json data
json = b'{"foo": "bar", "blub": "blah"}'
- response = client.post('/?a=b', data=json, content_type='application/json')
- strict_eq(response['data'], json)
- strict_eq(response['args'], MultiDict([('a', u'b')]))
- strict_eq(response['form'], MultiDict())
+ response = client.post("/?a=b", data=json, content_type="application/json")
+ strict_eq(response["data"], json)
+ strict_eq(response["args"], MultiDict([("a", u"b")]))
+ strict_eq(response["form"], MultiDict())
def test_query_string_is_bytes():
- req = wrappers.Request.from_values(u'/?foo=%2f')
- strict_eq(req.query_string, b'foo=%2f')
+ req = wrappers.Request.from_values(u"/?foo=%2f")
+ strict_eq(req.query_string, b"foo=%2f")
def test_request_repr():
- req = wrappers.Request.from_values('/foobar')
+ req = wrappers.Request.from_values("/foobar")
assert "<Request 'http://localhost/foobar' [GET]>" == repr(req)
# test with non-ascii characters
- req = wrappers.Request.from_values('/привет')
+ req = wrappers.Request.from_values("/привет")
assert "<Request 'http://localhost/привет' [GET]>" == repr(req)
# test with unicode type for python 2
- req = wrappers.Request.from_values(u'/привет')
+ req = wrappers.Request.from_values(u"/привет")
assert "<Request 'http://localhost/привет' [GET]>" == repr(req)
def test_access_route():
- req = wrappers.Request.from_values(headers={
- 'X-Forwarded-For': '192.168.1.2, 192.168.1.1'
- })
- req.environ['REMOTE_ADDR'] = '192.168.1.3'
- assert req.access_route == ['192.168.1.2', '192.168.1.1']
- strict_eq(req.remote_addr, '192.168.1.3')
+ req = wrappers.Request.from_values(
+ headers={"X-Forwarded-For": "192.168.1.2, 192.168.1.1"}
+ )
+ req.environ["REMOTE_ADDR"] = "192.168.1.3"
+ assert req.access_route == ["192.168.1.2", "192.168.1.1"]
+ strict_eq(req.remote_addr, "192.168.1.3")
req = wrappers.Request.from_values()
- req.environ['REMOTE_ADDR'] = '192.168.1.3'
- strict_eq(list(req.access_route), ['192.168.1.3'])
+ req.environ["REMOTE_ADDR"] = "192.168.1.3"
+ strict_eq(list(req.access_route), ["192.168.1.3"])
def test_url_request_descriptors():
- req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test')
- strict_eq(req.path, u'/bar')
- strict_eq(req.full_path, u'/bar?foo=baz')
- strict_eq(req.script_root, u'/test')
- strict_eq(req.url, u'http://example.com/test/bar?foo=baz')
- strict_eq(req.base_url, u'http://example.com/test/bar')
- strict_eq(req.url_root, u'http://example.com/test/')
- strict_eq(req.host_url, u'http://example.com/')
- strict_eq(req.host, 'example.com')
- strict_eq(req.scheme, 'http')
+ req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test")
+ strict_eq(req.path, u"/bar")
+ strict_eq(req.full_path, u"/bar?foo=baz")
+ strict_eq(req.script_root, u"/test")
+ strict_eq(req.url, u"http://example.com/test/bar?foo=baz")
+ strict_eq(req.base_url, u"http://example.com/test/bar")
+ strict_eq(req.url_root, u"http://example.com/test/")
+ strict_eq(req.host_url, u"http://example.com/")
+ strict_eq(req.host, "example.com")
+ strict_eq(req.scheme, "http")
- req = wrappers.Request.from_values('/bar?foo=baz', 'https://example.com/test')
- strict_eq(req.scheme, 'https')
+ req = wrappers.Request.from_values("/bar?foo=baz", "https://example.com/test")
+ strict_eq(req.scheme, "https")
def test_url_request_descriptors_query_quoting():
- next = 'http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash'
- req = wrappers.Request.from_values('/bar?next=' + next, 'http://example.com/')
- assert req.path == u'/bar'
- strict_eq(req.full_path, u'/bar?next=' + next)
- strict_eq(req.url, u'http://example.com/bar?next=' + next)
+ next = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash"
+ req = wrappers.Request.from_values("/bar?next=" + next, "http://example.com/")
+ assert req.path == u"/bar"
+ strict_eq(req.full_path, u"/bar?next=" + next)
+ strict_eq(req.url, u"http://example.com/bar?next=" + next)
def test_url_request_descriptors_hosts():
- req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test')
- req.trusted_hosts = ['example.com']
- strict_eq(req.path, u'/bar')
- strict_eq(req.full_path, u'/bar?foo=baz')
- strict_eq(req.script_root, u'/test')
- strict_eq(req.url, u'http://example.com/test/bar?foo=baz')
- strict_eq(req.base_url, u'http://example.com/test/bar')
- strict_eq(req.url_root, u'http://example.com/test/')
- strict_eq(req.host_url, u'http://example.com/')
- strict_eq(req.host, 'example.com')
- strict_eq(req.scheme, 'http')
-
- req = wrappers.Request.from_values('/bar?foo=baz', 'https://example.com/test')
- strict_eq(req.scheme, 'https')
-
- req = wrappers.Request.from_values('/bar?foo=baz', 'http://example.com/test')
- req.trusted_hosts = ['example.org']
+ req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test")
+ req.trusted_hosts = ["example.com"]
+ strict_eq(req.path, u"/bar")
+ strict_eq(req.full_path, u"/bar?foo=baz")
+ strict_eq(req.script_root, u"/test")
+ strict_eq(req.url, u"http://example.com/test/bar?foo=baz")
+ strict_eq(req.base_url, u"http://example.com/test/bar")
+ strict_eq(req.url_root, u"http://example.com/test/")
+ strict_eq(req.host_url, u"http://example.com/")
+ strict_eq(req.host, "example.com")
+ strict_eq(req.scheme, "http")
+
+ req = wrappers.Request.from_values("/bar?foo=baz", "https://example.com/test")
+ strict_eq(req.scheme, "https")
+
+ req = wrappers.Request.from_values("/bar?foo=baz", "http://example.com/test")
+ req.trusted_hosts = ["example.org"]
pytest.raises(SecurityError, lambda: req.url)
pytest.raises(SecurityError, lambda: req.base_url)
pytest.raises(SecurityError, lambda: req.url_root)
@@ -206,89 +223,106 @@ def test_url_request_descriptors_hosts():
def test_authorization_mixin():
- request = wrappers.Request.from_values(headers={
- 'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='
- })
+ request = wrappers.Request.from_values(
+ headers={"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="}
+ )
a = request.authorization
- strict_eq(a.type, 'basic')
- strict_eq(a.username, u'Aladdin')
- strict_eq(a.password, u'open sesame')
+ strict_eq(a.type, "basic")
+ strict_eq(a.username, u"Aladdin")
+ strict_eq(a.password, u"open sesame")
def test_authorization_with_unicode():
- request = wrappers.Request.from_values(headers={
- 'Authorization': 'Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw=='
- })
+ request = wrappers.Request.from_values(
+ headers={"Authorization": "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw=="}
+ )
a = request.authorization
- strict_eq(a.type, 'basic')
- strict_eq(a.username, u'русскиЁ')
- strict_eq(a.password, u'Буквы')
+ strict_eq(a.type, "basic")
+ strict_eq(a.username, u"русскиЁ")
+ strict_eq(a.password, u"Буквы")
def test_stream_only_mixing():
request = wrappers.PlainRequest.from_values(
- data=b'foo=blub+hehe',
- content_type='application/x-www-form-urlencoded'
+ data=b"foo=blub+hehe", content_type="application/x-www-form-urlencoded"
)
assert list(request.files.items()) == []
assert list(request.form.items()) == []
pytest.raises(AttributeError, lambda: request.data)
- strict_eq(request.stream.read(), b'foo=blub+hehe')
+ strict_eq(request.stream.read(), b"foo=blub+hehe")
def test_request_application():
@wrappers.Request.application
def application(request):
- return wrappers.Response('Hello World!')
+ return wrappers.Response("Hello World!")
@wrappers.Request.application
def failing_application(request):
raise BadRequest()
resp = wrappers.Response.from_app(application, create_environ())
- assert resp.data == b'Hello World!'
+ assert resp.data == b"Hello World!"
assert resp.status_code == 200
resp = wrappers.Response.from_app(failing_application, create_environ())
- assert b'Bad Request' in resp.data
+ assert b"Bad Request" in resp.data
assert resp.status_code == 400
def test_base_response():
# unicode
- response = wrappers.BaseResponse(u'öäü')
- strict_eq(response.get_data(), u'öäü'.encode('utf-8'))
+ response = wrappers.BaseResponse(u"öäü")
+ strict_eq(response.get_data(), u"öäü".encode("utf-8"))
# writing
- response = wrappers.Response('foo')
- response.stream.write('bar')
- strict_eq(response.get_data(), b'foobar')
+ response = wrappers.Response("foo")
+ response.stream.write("bar")
+ strict_eq(response.get_data(), b"foobar")
# set cookie
response = wrappers.BaseResponse()
- response.set_cookie('foo', value='bar', max_age=60, expires=0,
- path='/blub', domain='example.org', samesite='Strict')
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, '
- '01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; '
- 'SameSite=Strict')
- ])
+ response.set_cookie(
+ "foo",
+ value="bar",
+ max_age=60,
+ expires=0,
+ path="/blub",
+ domain="example.org",
+ samesite="Strict",
+ )
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=bar; Domain=example.org; Expires=Thu, "
+ "01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; "
+ "SameSite=Strict",
+ ),
+ ],
+ )
# delete cookie
response = wrappers.BaseResponse()
- response.delete_cookie('foo')
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=; Expires=Thu, 01-Jan-1970 00:00:00 GMT; Max-Age=0; Path=/')
- ])
+ response.delete_cookie("foo")
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=; Expires=Thu, 01-Jan-1970 00:00:00 GMT; Max-Age=0; Path=/",
+ ),
+ ],
+ )
# close call forwarding
closed = []
@implements_iterator
class Iterable(object):
-
def __next__(self):
raise StopIteration()
@@ -297,13 +331,12 @@ def test_base_response():
def close(self):
closed.append(True)
+
response = wrappers.BaseResponse(Iterable())
response.call_on_close(lambda: closed.append(True))
- app_iter, status, headers = run_wsgi_app(response,
- create_environ(),
- buffered=True)
- strict_eq(status, '200 OK')
- strict_eq(''.join(app_iter), '')
+ app_iter, status, headers = run_wsgi_app(response, create_environ(), buffered=True)
+ strict_eq(status, "200 OK")
+ strict_eq("".join(app_iter), "")
strict_eq(len(closed), 2)
# with statement
@@ -317,36 +350,36 @@ def test_base_response():
def test_response_status_codes():
response = wrappers.BaseResponse()
response.status_code = 404
- strict_eq(response.status, '404 NOT FOUND')
- response.status = '200 OK'
+ strict_eq(response.status, "404 NOT FOUND")
+ response.status = "200 OK"
strict_eq(response.status_code, 200)
- response.status = '999 WTF'
+ response.status = "999 WTF"
strict_eq(response.status_code, 999)
response.status_code = 588
strict_eq(response.status_code, 588)
- strict_eq(response.status, '588 UNKNOWN')
- response.status = 'wtf'
+ strict_eq(response.status, "588 UNKNOWN")
+ response.status = "wtf"
strict_eq(response.status_code, 0)
- strict_eq(response.status, '0 wtf')
+ strict_eq(response.status, "0 wtf")
# invalid status codes
with pytest.raises(ValueError) as empty_string_error:
- wrappers.BaseResponse(None, '')
- assert 'Empty status argument' in str(empty_string_error)
+ wrappers.BaseResponse(None, "")
+ assert "Empty status argument" in str(empty_string_error)
with pytest.raises(TypeError) as invalid_type_error:
wrappers.BaseResponse(None, tuple())
- assert 'Invalid status argument' in str(invalid_type_error)
+ assert "Invalid status argument" in str(invalid_type_error)
def test_type_forcing():
def wsgi_application(environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- return ['Hello World!']
- base_response = wrappers.BaseResponse('Hello World!', content_type='text/html')
+ start_response("200 OK", [("Content-Type", "text/html")])
+ return ["Hello World!"]
- class SpecialResponse(wrappers.Response):
+ base_response = wrappers.BaseResponse("Hello World!", content_type="text/html")
+ class SpecialResponse(wrappers.Response):
def foo(self):
return 42
@@ -358,54 +391,63 @@ def test_type_forcing():
response = SpecialResponse.force_type(orig_resp, fake_env)
assert response.__class__ is SpecialResponse
strict_eq(response.foo(), 42)
- strict_eq(response.get_data(), b'Hello World!')
- assert response.content_type == 'text/html'
+ strict_eq(response.get_data(), b"Hello World!")
+ assert response.content_type == "text/html"
# without env, no arbitrary conversion
pytest.raises(TypeError, SpecialResponse.force_type, wsgi_application)
def test_accept_mixin():
- request = wrappers.Request({
- 'HTTP_ACCEPT': 'text/xml,application/xml,application/xhtml+xml,'
- 'text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5',
- 'HTTP_ACCEPT_CHARSET': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
- 'HTTP_ACCEPT_ENCODING': 'gzip,deflate',
- 'HTTP_ACCEPT_LANGUAGE': 'en-us,en;q=0.5'
- })
- assert request.accept_mimetypes == MIMEAccept([
- ('text/xml', 1), ('image/png', 1), ('application/xml', 1),
- ('application/xhtml+xml', 1), ('text/html', 0.9),
- ('text/plain', 0.8), ('*/*', 0.5)
- ])
- strict_eq(request.accept_charsets, CharsetAccept([
- ('ISO-8859-1', 1), ('utf-8', 0.7), ('*', 0.7)
- ]))
- strict_eq(request.accept_encodings, Accept([
- ('gzip', 1), ('deflate', 1)]))
- strict_eq(request.accept_languages, LanguageAccept([
- ('en-us', 1), ('en', 0.5)]))
-
- request = wrappers.Request({'HTTP_ACCEPT': ''})
+ request = wrappers.Request(
+ {
+ "HTTP_ACCEPT": "text/xml,application/xml,application/xhtml+xml,"
+ "text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5",
+ "HTTP_ACCEPT_CHARSET": "ISO-8859-1,utf-8;q=0.7,*;q=0.7",
+ "HTTP_ACCEPT_ENCODING": "gzip,deflate",
+ "HTTP_ACCEPT_LANGUAGE": "en-us,en;q=0.5",
+ }
+ )
+ assert request.accept_mimetypes == MIMEAccept(
+ [
+ ("text/xml", 1),
+ ("image/png", 1),
+ ("application/xml", 1),
+ ("application/xhtml+xml", 1),
+ ("text/html", 0.9),
+ ("text/plain", 0.8),
+ ("*/*", 0.5),
+ ]
+ )
+ strict_eq(
+ request.accept_charsets,
+ CharsetAccept([("ISO-8859-1", 1), ("utf-8", 0.7), ("*", 0.7)]),
+ )
+ strict_eq(request.accept_encodings, Accept([("gzip", 1), ("deflate", 1)]))
+ strict_eq(request.accept_languages, LanguageAccept([("en-us", 1), ("en", 0.5)]))
+
+ request = wrappers.Request({"HTTP_ACCEPT": ""})
strict_eq(request.accept_mimetypes, MIMEAccept())
def test_etag_request_mixin():
- request = wrappers.Request({
- 'HTTP_CACHE_CONTROL': 'no-store, no-cache',
- 'HTTP_IF_MATCH': 'W/"foo", bar, "baz"',
- 'HTTP_IF_NONE_MATCH': 'W/"foo", bar, "baz"',
- 'HTTP_IF_MODIFIED_SINCE': 'Tue, 22 Jan 2008 11:18:44 GMT',
- 'HTTP_IF_UNMODIFIED_SINCE': 'Tue, 22 Jan 2008 11:18:44 GMT'
- })
+ request = wrappers.Request(
+ {
+ "HTTP_CACHE_CONTROL": "no-store, no-cache",
+ "HTTP_IF_MATCH": 'W/"foo", bar, "baz"',
+ "HTTP_IF_NONE_MATCH": 'W/"foo", bar, "baz"',
+ "HTTP_IF_MODIFIED_SINCE": "Tue, 22 Jan 2008 11:18:44 GMT",
+ "HTTP_IF_UNMODIFIED_SINCE": "Tue, 22 Jan 2008 11:18:44 GMT",
+ }
+ )
assert request.cache_control.no_store
assert request.cache_control.no_cache
for etags in request.if_match, request.if_none_match:
- assert etags('bar')
+ assert etags("bar")
assert etags.contains_raw('W/"foo"')
- assert etags.contains_weak('foo')
- assert not etags.contains('foo')
+ assert etags.contains_weak("foo")
+ assert not etags.contains("foo")
assert request.if_modified_since == datetime(2008, 1, 22, 11, 18, 44)
assert request.if_unmodified_since == datetime(2008, 1, 22, 11, 18, 44)
@@ -413,60 +455,170 @@ def test_etag_request_mixin():
def test_user_agent_mixin():
user_agents = [
- ('Mozilla/5.0 (Macintosh; U; Intel Mac OS X; en-US; rv:1.8.1.11) '
- 'Gecko/20071127 Firefox/2.0.0.11', 'firefox', 'macos', '2.0.0.11',
- 'en-US'),
- ('Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; de-DE) Opera 8.54',
- 'opera', 'windows', '8.54', 'de-DE'),
- ('Mozilla/5.0 (iPhone; U; CPU like Mac OS X; en) AppleWebKit/420 '
- '(KHTML, like Gecko) Version/3.0 Mobile/1A543a Safari/419.3',
- 'safari', 'iphone', '3.0', 'en'),
- ('Bot Googlebot/2.1 ( http://www.googlebot.com/bot.html)',
- 'google', None, '2.1', None),
- ('Mozilla/5.0 (X11; CrOS armv7l 3701.81.0) AppleWebKit/537.31 '
- '(KHTML, like Gecko) Chrome/26.0.1410.57 Safari/537.31',
- 'chrome', 'chromeos', '26.0.1410.57', None),
- ('Mozilla/5.0 (Windows NT 6.3; Trident/7.0; .NET4.0E; rv:11.0) like Gecko',
- 'msie', 'windows', '11.0', None),
- ('Mozilla/5.0 (SymbianOS/9.3; Series60/3.2 NokiaE5-00/101.003; '
- 'Profile/MIDP-2.1 Configuration/CLDC-1.1 ) AppleWebKit/533.4 (KHTML, like Gecko) '
- 'NokiaBrowser/7.3.1.35 Mobile Safari/533.4 3gpp-gba',
- 'safari', 'symbian', '533.4', None),
- ('Mozilla/5.0 (X11; OpenBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0',
- 'firefox', 'openbsd', '45.0', None),
- ('Mozilla/5.0 (X11; NetBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0',
- 'firefox', 'netbsd', '45.0', None),
- ('Mozilla/5.0 (X11; FreeBSD amd64) AppleWebKit/537.36 (KHTML, like Gecko) '
- 'Chrome/48.0.2564.103 Safari/537.36',
- 'chrome', 'freebsd', '48.0.2564.103', None),
- ('Mozilla/5.0 (X11; FreeBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0',
- 'firefox', 'freebsd', '45.0', None),
- ('Mozilla/5.0 (X11; U; NetBSD amd64; en-US; rv:) Gecko/20150921 SeaMonkey/1.1.18',
- 'seamonkey', 'netbsd', '1.1.18', 'en-US'),
- ('Mozilla/5.0 (Windows; U; Windows NT 6.2; WOW64; rv:1.8.0.7) '
- 'Gecko/20110321 MultiZilla/4.33.2.6a SeaMonkey/8.6.55',
- 'seamonkey', 'windows', '8.6.55', None),
- ('Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20120427 Firefox/12.0 SeaMonkey/2.9',
- 'seamonkey', 'linux', '2.9', None),
- ('Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)',
- 'baidu', None, '2.0', None),
- ('Mozilla/5.0 (X11; SunOS i86pc; rv:38.0) Gecko/20100101 Firefox/38.0',
- 'firefox', 'solaris', '38.0', None),
- ('Mozilla/5.0 (X11; Linux x86_64; rv:38.0) Gecko/20100101 Firefox/38.0 Iceweasel/38.7.1',
- 'firefox', 'linux', '38.0', None),
- ('Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
- 'Chrome/50.0.2661.75 Safari/537.36',
- 'chrome', 'windows', '50.0.2661.75', None),
- ('Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)',
- 'bing', None, '2.0', None),
- ('Mozilla/5.0 (X11; DragonFly x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
- 'Chrome/47.0.2526.106 Safari/537.36', 'chrome', 'dragonflybsd', '47.0.2526.106', None),
- ('Mozilla/5.0 (X11; U; DragonFly i386; de; rv:1.9.1) Gecko/20090720 Firefox/3.5.1',
- 'firefox', 'dragonflybsd', '3.5.1', 'de')
-
+ (
+ "Mozilla/5.0 (Macintosh; U; Intel Mac OS X; en-US; rv:1.8.1.11) "
+ "Gecko/20071127 Firefox/2.0.0.11",
+ "firefox",
+ "macos",
+ "2.0.0.11",
+ "en-US",
+ ),
+ (
+ "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; de-DE) Opera 8.54",
+ "opera",
+ "windows",
+ "8.54",
+ "de-DE",
+ ),
+ (
+ "Mozilla/5.0 (iPhone; U; CPU like Mac OS X; en) AppleWebKit/420 "
+ "(KHTML, like Gecko) Version/3.0 Mobile/1A543a Safari/419.3",
+ "safari",
+ "iphone",
+ "3.0",
+ "en",
+ ),
+ (
+ "Bot Googlebot/2.1 ( http://www.googlebot.com/bot.html)",
+ "google",
+ None,
+ "2.1",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; CrOS armv7l 3701.81.0) AppleWebKit/537.31 "
+ "(KHTML, like Gecko) Chrome/26.0.1410.57 Safari/537.31",
+ "chrome",
+ "chromeos",
+ "26.0.1410.57",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (Windows NT 6.3; Trident/7.0; .NET4.0E; rv:11.0) like Gecko",
+ "msie",
+ "windows",
+ "11.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (SymbianOS/9.3; Series60/3.2 NokiaE5-00/101.003; "
+ "Profile/MIDP-2.1 Configuration/CLDC-1.1 ) AppleWebKit/533.4 "
+ "(KHTML, like Gecko) NokiaBrowser/7.3.1.35 Mobile Safari/533.4 3gpp-gba",
+ "safari",
+ "symbian",
+ "533.4",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; OpenBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0",
+ "firefox",
+ "openbsd",
+ "45.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; NetBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0",
+ "firefox",
+ "netbsd",
+ "45.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; FreeBSD amd64) AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/48.0.2564.103 Safari/537.36",
+ "chrome",
+ "freebsd",
+ "48.0.2564.103",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; FreeBSD amd64; rv:45.0) Gecko/20100101 Firefox/45.0",
+ "firefox",
+ "freebsd",
+ "45.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; U; NetBSD amd64; en-US; rv:) Gecko/20150921 "
+ "SeaMonkey/1.1.18",
+ "seamonkey",
+ "netbsd",
+ "1.1.18",
+ "en-US",
+ ),
+ (
+ "Mozilla/5.0 (Windows; U; Windows NT 6.2; WOW64; rv:1.8.0.7) "
+ "Gecko/20110321 MultiZilla/4.33.2.6a SeaMonkey/8.6.55",
+ "seamonkey",
+ "windows",
+ "8.6.55",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20120427 Firefox/12.0 "
+ "SeaMonkey/2.9",
+ "seamonkey",
+ "linux",
+ "2.9",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (compatible; Baiduspider/2.0; "
+ "+http://www.baidu.com/search/spider.html)",
+ "baidu",
+ None,
+ "2.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; SunOS i86pc; rv:38.0) Gecko/20100101 Firefox/38.0",
+ "firefox",
+ "solaris",
+ "38.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; Linux x86_64; rv:38.0) Gecko/20100101 Firefox/38.0 "
+ "Iceweasel/38.7.1",
+ "firefox",
+ "linux",
+ "38.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
+ "(KHTML, like Gecko) Chrome/50.0.2661.75 Safari/537.36",
+ "chrome",
+ "windows",
+ "50.0.2661.75",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)",
+ "bing",
+ None,
+ "2.0",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; DragonFly x86_64) AppleWebKit/537.36 "
+ "(KHTML, like Gecko) Chrome/47.0.2526.106 Safari/537.36",
+ "chrome",
+ "dragonflybsd",
+ "47.0.2526.106",
+ None,
+ ),
+ (
+ "Mozilla/5.0 (X11; U; DragonFly i386; de; rv:1.9.1) "
+ "Gecko/20090720 Firefox/3.5.1",
+ "firefox",
+ "dragonflybsd",
+ "3.5.1",
+ "de",
+ ),
]
for ua, browser, platform, version, lang in user_agents:
- request = wrappers.Request({'HTTP_USER_AGENT': ua})
+ request = wrappers.Request({"HTTP_USER_AGENT": ua})
strict_eq(request.user_agent.browser, browser)
strict_eq(request.user_agent.platform, platform)
strict_eq(request.user_agent.version, version)
@@ -475,13 +627,12 @@ def test_user_agent_mixin():
strict_eq(request.user_agent.to_header(), ua)
strict_eq(str(request.user_agent), ua)
- request = wrappers.Request({'HTTP_USER_AGENT': 'foo'})
+ request = wrappers.Request({"HTTP_USER_AGENT": "foo"})
assert not request.user_agent
def test_stream_wrapping():
class LowercasingStream(object):
-
def __init__(self, stream):
self._stream = stream
@@ -491,121 +642,121 @@ def test_stream_wrapping():
def readline(self, size=-1):
return self._stream.readline(size).lower()
- data = b'foo=Hello+World'
+ data = b"foo=Hello+World"
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
req.stream = LowercasingStream(req.stream)
- assert req.form['foo'] == 'hello world'
+ assert req.form["foo"] == "hello world"
def test_data_descriptor_triggers_parsing():
- data = b'foo=Hello+World'
+ data = b"foo=Hello+World"
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
- assert req.data == b''
- assert req.form['foo'] == u'Hello World'
+ assert req.data == b""
+ assert req.form["foo"] == u"Hello World"
def test_get_data_method_parsing_caching_behavior():
- data = b'foo=Hello+World'
+ data = b"foo=Hello+World"
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
# get_data() caches, so form stays available
assert req.get_data() == data
- assert req.form['foo'] == u'Hello World'
+ assert req.form["foo"] == u"Hello World"
assert req.get_data() == data
# here we access the form data first, caching is bypassed
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
- assert req.form['foo'] == u'Hello World'
- assert req.get_data() == b''
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
+ assert req.form["foo"] == u"Hello World"
+ assert req.get_data() == b""
# Another case is uncached get data which trashes everything
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
assert req.get_data(cache=False) == data
- assert req.get_data(cache=False) == b''
+ assert req.get_data(cache=False) == b""
assert req.form == {}
# Or we can implicitly start the form parser which is similar to
# the old .data behavior
req = wrappers.Request.from_values(
- '/', method='POST', data=data,
- content_type='application/x-www-form-urlencoded')
- assert req.get_data(parse_form_data=True) == b''
- assert req.form['foo'] == u'Hello World'
+ "/", method="POST", data=data, content_type="application/x-www-form-urlencoded"
+ )
+ assert req.get_data(parse_form_data=True) == b""
+ assert req.form["foo"] == u"Hello World"
def test_etag_response_mixin():
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
assert response.get_etag() == (None, None)
response.add_etag()
- assert response.get_etag() == ('b10a8db164e0754105b7a99be72e3fe5', False)
+ assert response.get_etag() == ("b10a8db164e0754105b7a99be72e3fe5", False)
assert not response.cache_control
response.cache_control.must_revalidate = True
response.cache_control.max_age = 60
- response.headers['Content-Length'] = len(response.get_data())
- assert response.headers['Cache-Control'] in ('must-revalidate, max-age=60',
- 'max-age=60, must-revalidate')
+ response.headers["Content-Length"] = len(response.get_data())
+ assert response.headers["Cache-Control"] in (
+ "must-revalidate, max-age=60",
+ "max-age=60, must-revalidate",
+ )
- assert 'date' not in response.headers
+ assert "date" not in response.headers
env = create_environ()
- env.update({
- 'REQUEST_METHOD': 'GET',
- 'HTTP_IF_NONE_MATCH': response.get_etag()[0]
- })
+ env.update({"REQUEST_METHOD": "GET", "HTTP_IF_NONE_MATCH": response.get_etag()[0]})
response.make_conditional(env)
- assert 'date' in response.headers
+ assert "date" in response.headers
# after the thing is invoked by the server as wsgi application
# (we're emulating this here), there must not be any entity
# headers left and the status code would have to be 304
resp = wrappers.Response.from_app(response, env)
assert resp.status_code == 304
- assert 'content-length' not in resp.headers
+ assert "content-length" not in resp.headers
# make sure date is not overriden
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
response.date = 1337
d = response.date
response.make_conditional(env)
assert response.date == d
# make sure content length is only set if missing
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
response.content_length = 999
response.make_conditional(env)
assert response.content_length == 999
def test_etag_response_412():
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
assert response.get_etag() == (None, None)
response.add_etag()
- assert response.get_etag() == ('b10a8db164e0754105b7a99be72e3fe5', False)
+ assert response.get_etag() == ("b10a8db164e0754105b7a99be72e3fe5", False)
assert not response.cache_control
response.cache_control.must_revalidate = True
response.cache_control.max_age = 60
- response.headers['Content-Length'] = len(response.get_data())
- assert response.headers['Cache-Control'] in ('must-revalidate, max-age=60',
- 'max-age=60, must-revalidate')
+ response.headers["Content-Length"] = len(response.get_data())
+ assert response.headers["Cache-Control"] in (
+ "must-revalidate, max-age=60",
+ "max-age=60, must-revalidate",
+ )
- assert 'date' not in response.headers
+ assert "date" not in response.headers
env = create_environ()
- env.update({
- 'REQUEST_METHOD': 'GET',
- 'HTTP_IF_MATCH': response.get_etag()[0] + "xyz"
- })
+ env.update(
+ {"REQUEST_METHOD": "GET", "HTTP_IF_MATCH": response.get_etag()[0] + "xyz"}
+ )
response.make_conditional(env)
- assert 'date' in response.headers
+ assert "date" in response.headers
# after the thing is invoked by the server as wsgi application
# (we're emulating this here), there must not be any entity
@@ -613,17 +764,17 @@ def test_etag_response_412():
resp = wrappers.Response.from_app(response, env)
assert resp.status_code == 412
# Make sure there is a body still
- assert resp.data != b''
+ assert resp.data != b""
# make sure date is not overriden
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
response.date = 1337
d = response.date
response.make_conditional(env)
assert response.date == d
# make sure content length is only set if missing
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
response.content_length = 999
response.make_conditional(env)
assert response.content_length == 999
@@ -631,77 +782,78 @@ def test_etag_response_412():
def test_range_request_basic():
env = create_environ()
- response = wrappers.Response('Hello World')
- env['HTTP_RANGE'] = 'bytes=0-4'
+ response = wrappers.Response("Hello World")
+ env["HTTP_RANGE"] = "bytes=0-4"
response.make_conditional(env, accept_ranges=True, complete_length=11)
assert response.status_code == 206
- assert response.headers['Accept-Ranges'] == 'bytes'
- assert response.headers['Content-Range'] == 'bytes 0-4/11'
- assert response.headers['Content-Length'] == '5'
- assert response.data == b'Hello'
+ assert response.headers["Accept-Ranges"] == "bytes"
+ assert response.headers["Content-Range"] == "bytes 0-4/11"
+ assert response.headers["Content-Length"] == "5"
+ assert response.data == b"Hello"
def test_range_request_out_of_bound():
env = create_environ()
- response = wrappers.Response('Hello World')
- env['HTTP_RANGE'] = 'bytes=6-666'
+ response = wrappers.Response("Hello World")
+ env["HTTP_RANGE"] = "bytes=6-666"
response.make_conditional(env, accept_ranges=True, complete_length=11)
assert response.status_code == 206
- assert response.headers['Accept-Ranges'] == 'bytes'
- assert response.headers['Content-Range'] == 'bytes 6-10/11'
- assert response.headers['Content-Length'] == '5'
- assert response.data == b'World'
+ assert response.headers["Accept-Ranges"] == "bytes"
+ assert response.headers["Content-Range"] == "bytes 6-10/11"
+ assert response.headers["Content-Length"] == "5"
+ assert response.data == b"World"
def test_range_request_with_file():
env = create_environ()
- resources = os.path.join(os.path.dirname(__file__), 'res')
- fname = os.path.join(resources, 'test.txt')
- with open(fname, 'rb') as f:
+ resources = os.path.join(os.path.dirname(__file__), "res")
+ fname = os.path.join(resources, "test.txt")
+ with open(fname, "rb") as f:
fcontent = f.read()
- with open(fname, 'rb') as f:
+ with open(fname, "rb") as f:
response = wrappers.Response(wrap_file(env, f))
- env['HTTP_RANGE'] = 'bytes=0-0'
- response.make_conditional(env, accept_ranges=True, complete_length=len(fcontent))
+ env["HTTP_RANGE"] = "bytes=0-0"
+ response.make_conditional(
+ env, accept_ranges=True, complete_length=len(fcontent)
+ )
assert response.status_code == 206
- assert response.headers['Accept-Ranges'] == 'bytes'
- assert response.headers['Content-Range'] == 'bytes 0-0/%d' % len(fcontent)
- assert response.headers['Content-Length'] == '1'
+ assert response.headers["Accept-Ranges"] == "bytes"
+ assert response.headers["Content-Range"] == "bytes 0-0/%d" % len(fcontent)
+ assert response.headers["Content-Length"] == "1"
assert response.data == fcontent[:1]
def test_range_request_with_complete_file():
env = create_environ()
- resources = os.path.join(os.path.dirname(__file__), 'res')
- fname = os.path.join(resources, 'test.txt')
- with open(fname, 'rb') as f:
+ resources = os.path.join(os.path.dirname(__file__), "res")
+ fname = os.path.join(resources, "test.txt")
+ with open(fname, "rb") as f:
fcontent = f.read()
- with open(fname, 'rb') as f:
+ with open(fname, "rb") as f:
fsize = os.path.getsize(fname)
response = wrappers.Response(wrap_file(env, f))
- env['HTTP_RANGE'] = 'bytes=0-%d' % (fsize - 1)
- response.make_conditional(env, accept_ranges=True,
- complete_length=fsize)
+ env["HTTP_RANGE"] = "bytes=0-%d" % (fsize - 1)
+ response.make_conditional(env, accept_ranges=True, complete_length=fsize)
assert response.status_code == 200
- assert response.headers['Accept-Ranges'] == 'bytes'
- assert 'Content-Range' not in response.headers
- assert response.headers['Content-Length'] == str(fsize)
+ assert response.headers["Accept-Ranges"] == "bytes"
+ assert "Content-Range" not in response.headers
+ assert response.headers["Content-Length"] == str(fsize)
assert response.data == fcontent
def test_range_request_without_complete_length():
env = create_environ()
- response = wrappers.Response('Hello World')
- env['HTTP_RANGE'] = 'bytes=-'
+ response = wrappers.Response("Hello World")
+ env["HTTP_RANGE"] = "bytes=-"
response.make_conditional(env, accept_ranges=True, complete_length=None)
assert response.status_code == 200
- assert response.data == b'Hello World'
+ assert response.data == b"Hello World"
def test_invalid_range_request():
env = create_environ()
- response = wrappers.Response('Hello World')
- env['HTTP_RANGE'] = 'bytes=-'
+ response = wrappers.Response("Hello World")
+ env["HTTP_RANGE"] = "bytes=-"
with pytest.raises(RequestedRangeNotSatisfiable):
response.make_conditional(env, accept_ranges=True, complete_length=11)
@@ -713,69 +865,68 @@ def test_etag_response_mixin_freezing():
class WithoutFreeze(wrappers.BaseResponse, wrappers.ETagResponseMixin):
pass
- response = WithFreeze('Hello World')
+ response = WithFreeze("Hello World")
response.freeze()
- strict_eq(response.get_etag(),
- (text_type(generate_etag(b'Hello World')), False))
- response = WithoutFreeze('Hello World')
+ strict_eq(response.get_etag(), (text_type(generate_etag(b"Hello World")), False))
+ response = WithoutFreeze("Hello World")
response.freeze()
assert response.get_etag() == (None, None)
- response = wrappers.Response('Hello World')
+ response = wrappers.Response("Hello World")
response.freeze()
assert response.get_etag() == (None, None)
def test_authenticate_mixin():
resp = wrappers.Response()
- resp.www_authenticate.type = 'basic'
- resp.www_authenticate.realm = 'Testing'
- strict_eq(resp.headers['WWW-Authenticate'], u'Basic realm="Testing"')
+ resp.www_authenticate.type = "basic"
+ resp.www_authenticate.realm = "Testing"
+ strict_eq(resp.headers["WWW-Authenticate"], u'Basic realm="Testing"')
resp.www_authenticate.realm = None
resp.www_authenticate.type = None
- assert 'WWW-Authenticate' not in resp.headers
+ assert "WWW-Authenticate" not in resp.headers
def test_authenticate_mixin_quoted_qop():
# Example taken from https://github.com/pallets/werkzeug/issues/633
resp = wrappers.Response()
- resp.www_authenticate.set_digest('REALM', 'NONCE', qop=("auth", "auth-int"))
+ resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth", "auth-int"))
- actual = set((resp.headers['WWW-Authenticate'] + ',').split())
+ actual = set((resp.headers["WWW-Authenticate"] + ",").split())
expected = set('Digest nonce="NONCE", realm="REALM", qop="auth, auth-int",'.split())
assert actual == expected
- resp.www_authenticate.set_digest('REALM', 'NONCE', qop=("auth",))
+ resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth",))
- actual = set((resp.headers['WWW-Authenticate'] + ',').split())
+ actual = set((resp.headers["WWW-Authenticate"] + ",").split())
expected = set('Digest nonce="NONCE", realm="REALM", qop="auth",'.split())
assert actual == expected
def test_response_stream_mixin():
response = wrappers.Response()
- response.stream.write('Hello ')
- response.stream.write('World!')
- assert response.response == ['Hello ', 'World!']
- assert response.get_data() == b'Hello World!'
+ response.stream.write("Hello ")
+ response.stream.write("World!")
+ assert response.response == ["Hello ", "World!"]
+ assert response.get_data() == b"Hello World!"
def test_common_response_descriptors_mixin():
response = wrappers.Response()
- response.mimetype = 'text/html'
- assert response.mimetype == 'text/html'
- assert response.content_type == 'text/html; charset=utf-8'
- assert response.mimetype_params == {'charset': 'utf-8'}
- response.mimetype_params['x-foo'] = 'yep'
- del response.mimetype_params['charset']
- assert response.content_type == 'text/html; x-foo=yep'
+ response.mimetype = "text/html"
+ assert response.mimetype == "text/html"
+ assert response.content_type == "text/html; charset=utf-8"
+ assert response.mimetype_params == {"charset": "utf-8"}
+ response.mimetype_params["x-foo"] = "yep"
+ del response.mimetype_params["charset"]
+ assert response.content_type == "text/html; x-foo=yep"
now = datetime.utcnow().replace(microsecond=0)
assert response.content_length is None
- response.content_length = '42'
+ response.content_length = "42"
assert response.content_length == 42
- for attr in 'date', 'expires':
+ for attr in "date", "expires":
assert getattr(response, attr) is None
setattr(response, attr, now)
assert getattr(response, attr) == now
@@ -792,95 +943,97 @@ def test_common_response_descriptors_mixin():
assert response.retry_after == now
assert not response.vary
- response.vary.add('Cookie')
- response.vary.add('Content-Language')
- assert 'cookie' in response.vary
- assert response.vary.to_header() == 'Cookie, Content-Language'
- response.headers['Vary'] = 'Content-Encoding'
- assert response.vary.as_set() == set(['content-encoding'])
+ response.vary.add("Cookie")
+ response.vary.add("Content-Language")
+ assert "cookie" in response.vary
+ assert response.vary.to_header() == "Cookie, Content-Language"
+ response.headers["Vary"] = "Content-Encoding"
+ assert response.vary.as_set() == {"content-encoding"}
- response.allow.update(['GET', 'POST'])
- assert response.headers['Allow'] == 'GET, POST'
+ response.allow.update(["GET", "POST"])
+ assert response.headers["Allow"] == "GET, POST"
- response.content_language.add('en-US')
- response.content_language.add('fr')
- assert response.headers['Content-Language'] == 'en-US, fr'
+ response.content_language.add("en-US")
+ response.content_language.add("fr")
+ assert response.headers["Content-Language"] == "en-US, fr"
def test_common_request_descriptors_mixin():
request = wrappers.Request.from_values(
- content_type='text/html; charset=utf-8',
- content_length='23',
+ content_type="text/html; charset=utf-8",
+ content_length="23",
headers={
- 'Referer': 'http://www.example.com/',
- 'Date': 'Sat, 28 Feb 2009 19:04:35 GMT',
- 'Max-Forwards': '10',
- 'Pragma': 'no-cache',
- 'Content-Encoding': 'gzip',
- 'Content-MD5': '9a3bc6dbc47a70db25b84c6e5867a072'
- }
+ "Referer": "http://www.example.com/",
+ "Date": "Sat, 28 Feb 2009 19:04:35 GMT",
+ "Max-Forwards": "10",
+ "Pragma": "no-cache",
+ "Content-Encoding": "gzip",
+ "Content-MD5": "9a3bc6dbc47a70db25b84c6e5867a072",
+ },
)
- assert request.content_type == 'text/html; charset=utf-8'
- assert request.mimetype == 'text/html'
- assert request.mimetype_params == {'charset': 'utf-8'}
+ assert request.content_type == "text/html; charset=utf-8"
+ assert request.mimetype == "text/html"
+ assert request.mimetype_params == {"charset": "utf-8"}
assert request.content_length == 23
- assert request.referrer == 'http://www.example.com/'
+ assert request.referrer == "http://www.example.com/"
assert request.date == datetime(2009, 2, 28, 19, 4, 35)
assert request.max_forwards == 10
- assert 'no-cache' in request.pragma
- assert request.content_encoding == 'gzip'
- assert request.content_md5 == '9a3bc6dbc47a70db25b84c6e5867a072'
+ assert "no-cache" in request.pragma
+ assert request.content_encoding == "gzip"
+ assert request.content_md5 == "9a3bc6dbc47a70db25b84c6e5867a072"
def test_request_mimetype_always_lowercase():
- request = wrappers.Request.from_values(content_type='APPLICATION/JSON')
- assert request.mimetype == 'application/json'
+ request = wrappers.Request.from_values(content_type="APPLICATION/JSON")
+ assert request.mimetype == "application/json"
def test_shallow_mode():
- request = wrappers.Request({'QUERY_STRING': 'foo=bar'}, shallow=True)
- assert request.args['foo'] == 'bar'
- pytest.raises(RuntimeError, lambda: request.form['foo'])
+ request = wrappers.Request({"QUERY_STRING": "foo=bar"}, shallow=True)
+ assert request.args["foo"] == "bar"
+ pytest.raises(RuntimeError, lambda: request.form["foo"])
def test_form_parsing_failed():
- data = b'--blah\r\n'
+ data = b"--blah\r\n"
request = wrappers.Request.from_values(
input_stream=BytesIO(data),
content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST'
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
)
assert not request.files
assert not request.form
# Bad Content-Type
- data = b'test'
+ data = b"test"
request = wrappers.Request.from_values(
input_stream=BytesIO(data),
content_length=len(data),
- content_type=', ',
- method='POST'
+ content_type=", ",
+ method="POST",
)
assert not request.form
def test_file_closing():
- data = (b'--foo\r\n'
- b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
- b'Content-Type: text/plain; charset=utf-8\r\n\r\n'
- b'file contents, just the contents\r\n'
- b'--foo--')
+ data = (
+ b"--foo\r\n"
+ b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
+ b"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+ b"file contents, just the contents\r\n"
+ b"--foo--"
+ )
req = wrappers.Request.from_values(
input_stream=BytesIO(data),
content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST'
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
)
- foo = req.files['foo']
- assert foo.mimetype == 'text/plain'
- assert foo.filename == 'foo.txt'
+ foo = req.files["foo"]
+ assert foo.mimetype == "text/plain"
+ assert foo.filename == "foo.txt"
assert foo.closed is False
req.close()
@@ -888,29 +1041,31 @@ def test_file_closing():
def test_file_closing_with():
- data = (b'--foo\r\n'
- b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
- b'Content-Type: text/plain; charset=utf-8\r\n\r\n'
- b'file contents, just the contents\r\n'
- b'--foo--')
+ data = (
+ b"--foo\r\n"
+ b'Content-Disposition: form-data; name="foo"; filename="foo.txt"\r\n'
+ b"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+ b"file contents, just the contents\r\n"
+ b"--foo--"
+ )
req = wrappers.Request.from_values(
input_stream=BytesIO(data),
content_length=len(data),
- content_type='multipart/form-data; boundary=foo',
- method='POST'
+ content_type="multipart/form-data; boundary=foo",
+ method="POST",
)
with req:
- foo = req.files['foo']
- assert foo.mimetype == 'text/plain'
- assert foo.filename == 'foo.txt'
+ foo = req.files["foo"]
+ assert foo.mimetype == "text/plain"
+ assert foo.filename == "foo.txt"
assert foo.closed is True
def test_url_charset_reflection():
req = wrappers.Request.from_values()
- req.charset = 'utf-7'
- assert req.url_charset == 'utf-7'
+ req.charset = "utf-7"
+ assert req.url_charset == "utf-7"
def test_response_streamed():
@@ -924,6 +1079,7 @@ def test_response_streamed():
def gen():
if 0:
yield None
+
r = wrappers.Response(gen())
assert r.is_streamed
@@ -934,58 +1090,61 @@ def test_response_iter_wrapping():
yield item.upper()
def generator():
- yield 'foo'
- yield 'bar'
+ yield "foo"
+ yield "bar"
+
req = wrappers.Request.from_values()
resp = wrappers.Response(generator())
- del resp.headers['Content-Length']
+ del resp.headers["Content-Length"]
resp.response = uppercasing(resp.iter_encoded())
actual_resp = wrappers.Response.from_app(resp, req.environ, buffered=True)
- assert actual_resp.get_data() == b'FOOBAR'
+ assert actual_resp.get_data() == b"FOOBAR"
def test_response_freeze():
def generate():
yield "foo"
yield "bar"
+
resp = wrappers.Response(generate())
resp.freeze()
- assert resp.response == [b'foo', b'bar']
- assert resp.headers['content-length'] == '6'
+ assert resp.response == [b"foo", b"bar"]
+ assert resp.headers["content-length"] == "6"
def test_response_content_length_uses_encode():
- r = wrappers.Response(u'你好')
+ r = wrappers.Response(u"你好")
assert r.calculate_content_length() == 6
def test_other_method_payload():
- data = b'Hello World'
- req = wrappers.Request.from_values(input_stream=BytesIO(data),
- content_length=len(data),
- content_type='text/plain',
- method='WHAT_THE_FUCK')
+ data = b"Hello World"
+ req = wrappers.Request.from_values(
+ input_stream=BytesIO(data),
+ content_length=len(data),
+ content_type="text/plain",
+ method="WHAT_THE_FUCK",
+ )
assert req.get_data() == data
assert isinstance(req.stream, LimitedStream)
def test_urlfication():
resp = wrappers.Response()
- resp.headers['Location'] = u'http://üser:pässword@☃.net/påth'
- resp.headers['Content-Location'] = u'http://☃.net/'
+ resp.headers["Location"] = u"http://üser:pässword@☃.net/påth"
+ resp.headers["Content-Location"] = u"http://☃.net/"
headers = resp.get_wsgi_headers(create_environ())
- assert headers['location'] == \
- 'http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th'
- assert headers['content-location'] == 'http://xn--n3h.net/'
+ assert headers["location"] == "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th"
+ assert headers["content-location"] == "http://xn--n3h.net/"
def test_new_response_iterator_behavior():
req = wrappers.Request.from_values()
- resp = wrappers.Response(u'Hello Wörld!')
+ resp = wrappers.Response(u"Hello Wörld!")
def get_content_length(resp):
headers = resp.get_wsgi_headers(req.environ)
- return headers.get('content-length', type=int)
+ return headers.get("content-length", type=int)
def generate_items():
yield "Hello "
@@ -993,16 +1152,16 @@ def test_new_response_iterator_behavior():
# werkzeug encodes when set to `data` now, which happens
# if a string is passed to the response object.
- assert resp.response == [u'Hello Wörld!'.encode('utf-8')]
- assert resp.get_data() == u'Hello Wörld!'.encode('utf-8')
+ assert resp.response == [u"Hello Wörld!".encode("utf-8")]
+ assert resp.get_data() == u"Hello Wörld!".encode("utf-8")
assert get_content_length(resp) == 13
assert not resp.is_streamed
assert resp.is_sequence
# try the same for manual assignment
- resp.set_data(u'Wörd')
- assert resp.response == [u'Wörd'.encode('utf-8')]
- assert resp.get_data() == u'Wörd'.encode('utf-8')
+ resp.set_data(u"Wörd")
+ assert resp.response == [u"Wörd".encode("utf-8")]
+ assert resp.get_data() == u"Wörd".encode("utf-8")
assert get_content_length(resp) == 5
assert not resp.is_streamed
assert resp.is_sequence
@@ -1011,8 +1170,8 @@ def test_new_response_iterator_behavior():
resp.response = generate_items()
assert resp.is_streamed
assert not resp.is_sequence
- assert resp.get_data() == u'Hello Wörld!'.encode('utf-8')
- assert resp.response == [b'Hello ', u'Wörld!'.encode('utf-8')]
+ assert resp.get_data() == u"Hello Wörld!".encode("utf-8")
+ assert resp.response == [b"Hello ", u"Wörld!".encode("utf-8")]
assert not resp.is_streamed
assert resp.is_sequence
@@ -1023,8 +1182,8 @@ def test_new_response_iterator_behavior():
assert not resp.is_sequence
pytest.raises(RuntimeError, lambda: resp.get_data())
resp.make_sequence()
- assert resp.get_data() == u'Hello Wörld!'.encode('utf-8')
- assert resp.response == [b'Hello ', u'Wörld!'.encode('utf-8')]
+ assert resp.get_data() == u"Hello Wörld!".encode("utf-8")
+ assert resp.response == [b"Hello ", u"Wörld!".encode("utf-8")]
assert not resp.is_streamed
assert resp.is_sequence
@@ -1033,25 +1192,25 @@ def test_new_response_iterator_behavior():
resp.implicit_sequence_conversion = val
resp.response = ("foo", "bar")
assert resp.is_sequence
- resp.stream.write('baz')
- assert resp.response == ['foo', 'bar', 'baz']
+ resp.stream.write("baz")
+ assert resp.response == ["foo", "bar", "baz"]
def test_form_data_ordering():
class MyRequest(wrappers.Request):
parameter_storage_class = ImmutableOrderedMultiDict
- req = MyRequest.from_values('/?foo=1&bar=0&foo=3')
- assert list(req.args) == ['foo', 'bar']
+ req = MyRequest.from_values("/?foo=1&bar=0&foo=3")
+ assert list(req.args) == ["foo", "bar"]
assert list(req.args.items(multi=True)) == [
- ('foo', '1'),
- ('bar', '0'),
- ('foo', '3')
+ ("foo", "1"),
+ ("bar", "0"),
+ ("foo", "3"),
]
assert isinstance(req.args, ImmutableOrderedMultiDict)
assert isinstance(req.values, CombinedMultiDict)
- assert req.values['foo'] == '1'
- assert req.values.getlist('foo') == ['1', '3']
+ assert req.values["foo"] == "1"
+ assert req.values.getlist("foo") == ["1", "3"]
def test_storage_classes():
@@ -1059,22 +1218,19 @@ def test_storage_classes():
dict_storage_class = dict
list_storage_class = list
parameter_storage_class = dict
- req = MyRequest.from_values('/?foo=baz', headers={
- 'Cookie': 'foo=bar'
- })
+
+ req = MyRequest.from_values("/?foo=baz", headers={"Cookie": "foo=bar"})
assert type(req.cookies) is dict
- assert req.cookies == {'foo': 'bar'}
+ assert req.cookies == {"foo": "bar"}
assert type(req.access_route) is list
assert type(req.args) is dict
assert type(req.values) is CombinedMultiDict
- assert req.values['foo'] == u'baz'
+ assert req.values["foo"] == u"baz"
- req = wrappers.Request.from_values(headers={
- 'Cookie': 'foo=bar'
- })
+ req = wrappers.Request.from_values(headers={"Cookie": "foo=bar"})
assert type(req.cookies) is ImmutableTypeConversionDict
- assert req.cookies == {'foo': 'bar'}
+ assert req.cookies == {"foo": "bar"}
assert type(req.access_route) is ImmutableList
MyRequest.list_storage_class = tuple
@@ -1089,91 +1245,93 @@ def test_response_headers_passthrough():
def test_response_304_no_content_length():
- resp = wrappers.Response('Test', status=304)
+ resp = wrappers.Response("Test", status=304)
env = create_environ()
- assert 'content-length' not in resp.get_wsgi_headers(env)
+ assert "content-length" not in resp.get_wsgi_headers(env)
def test_ranges():
# basic range stuff
req = wrappers.Request.from_values()
assert req.range is None
- req = wrappers.Request.from_values(headers={'Range': 'bytes=0-499'})
+ req = wrappers.Request.from_values(headers={"Range": "bytes=0-499"})
assert req.range.ranges == [(0, 500)]
resp = wrappers.Response()
resp.content_range = req.range.make_content_range(1000)
- assert resp.content_range.units == 'bytes'
+ assert resp.content_range.units == "bytes"
assert resp.content_range.start == 0
assert resp.content_range.stop == 500
assert resp.content_range.length == 1000
- assert resp.headers['Content-Range'] == 'bytes 0-499/1000'
+ assert resp.headers["Content-Range"] == "bytes 0-499/1000"
resp.content_range.unset()
- assert 'Content-Range' not in resp.headers
+ assert "Content-Range" not in resp.headers
- resp.headers['Content-Range'] = 'bytes 0-499/1000'
- assert resp.content_range.units == 'bytes'
+ resp.headers["Content-Range"] = "bytes 0-499/1000"
+ assert resp.content_range.units == "bytes"
assert resp.content_range.start == 0
assert resp.content_range.stop == 500
assert resp.content_range.length == 1000
def test_auto_content_length():
- resp = wrappers.Response('Hello World!')
+ resp = wrappers.Response("Hello World!")
assert resp.content_length == 12
- resp = wrappers.Response(['Hello World!'])
+ resp = wrappers.Response(["Hello World!"])
assert resp.content_length is None
- assert resp.get_wsgi_headers({})['Content-Length'] == '12'
+ assert resp.get_wsgi_headers({})["Content-Length"] == "12"
def test_stream_content_length():
resp = wrappers.Response()
- resp.stream.writelines(['foo', 'bar', 'baz'])
- assert resp.get_wsgi_headers({})['Content-Length'] == '9'
+ resp.stream.writelines(["foo", "bar", "baz"])
+ assert resp.get_wsgi_headers({})["Content-Length"] == "9"
resp = wrappers.Response()
- resp.make_conditional({'REQUEST_METHOD': 'GET'})
- resp.stream.writelines(['foo', 'bar', 'baz'])
- assert resp.get_wsgi_headers({})['Content-Length'] == '9'
+ resp.make_conditional({"REQUEST_METHOD": "GET"})
+ resp.stream.writelines(["foo", "bar", "baz"])
+ assert resp.get_wsgi_headers({})["Content-Length"] == "9"
- resp = wrappers.Response('foo')
- resp.stream.writelines(['bar', 'baz'])
- assert resp.get_wsgi_headers({})['Content-Length'] == '9'
+ resp = wrappers.Response("foo")
+ resp.stream.writelines(["bar", "baz"])
+ assert resp.get_wsgi_headers({})["Content-Length"] == "9"
def test_disabled_auto_content_length():
class MyResponse(wrappers.Response):
automatically_set_content_length = False
- resp = MyResponse('Hello World!')
+
+ resp = MyResponse("Hello World!")
assert resp.content_length is None
- resp = MyResponse(['Hello World!'])
+ resp = MyResponse(["Hello World!"])
assert resp.content_length is None
- assert 'Content-Length' not in resp.get_wsgi_headers({})
+ assert "Content-Length" not in resp.get_wsgi_headers({})
resp = MyResponse()
- resp.make_conditional({
- 'REQUEST_METHOD': 'GET'
- })
+ resp.make_conditional({"REQUEST_METHOD": "GET"})
assert resp.content_length is None
- assert 'Content-Length' not in resp.get_wsgi_headers({})
-
-
-@pytest.mark.parametrize(('auto', 'location', 'expect'), (
- (False, '/test', '/test'),
- (True, '/test', 'http://localhost/test'),
- (True, 'test', 'http://localhost/a/b/test'),
- (True, './test', 'http://localhost/a/b/test'),
- (True, '../test', 'http://localhost/a/test'),
-))
+ assert "Content-Length" not in resp.get_wsgi_headers({})
+
+
+@pytest.mark.parametrize(
+ ("auto", "location", "expect"),
+ (
+ (False, "/test", "/test"),
+ (True, "/test", "http://localhost/test"),
+ (True, "test", "http://localhost/a/b/test"),
+ (True, "./test", "http://localhost/a/b/test"),
+ (True, "../test", "http://localhost/a/test"),
+ ),
+)
def test_location_header_autocorrect(monkeypatch, auto, location, expect):
- monkeypatch.setattr(wrappers.Response, 'autocorrect_location_header', auto)
- env = create_environ('/a/b/c')
- resp = wrappers.Response('Hello World!')
- resp.headers['Location'] = location
- assert resp.get_wsgi_headers(env)['Location'] == expect
+ monkeypatch.setattr(wrappers.Response, "autocorrect_location_header", auto)
+ env = create_environ("/a/b/c")
+ resp = wrappers.Response("Hello World!")
+ resp.headers["Location"] = location
+ assert resp.get_wsgi_headers(env)["Location"] == expect
def test_204_and_1XX_response_has_no_content_length():
@@ -1181,39 +1339,39 @@ def test_204_and_1XX_response_has_no_content_length():
assert response.content_length is None
headers = response.get_wsgi_headers(create_environ())
- assert 'Content-Length' not in headers
+ assert "Content-Length" not in headers
response = wrappers.Response(status=100)
assert response.content_length is None
headers = response.get_wsgi_headers(create_environ())
- assert 'Content-Length' not in headers
+ assert "Content-Length" not in headers
def test_malformed_204_response_has_no_content_length():
# flask-restful can generate a malformed response when doing `return '', 204`
response = wrappers.Response(status=204)
- response.set_data(b'test')
+ response.set_data(b"test")
assert response.content_length == 4
env = create_environ()
app_iter, status, headers = response.get_wsgi_response(env)
- assert status == '204 NO CONTENT'
- assert 'Content-Length' not in headers
- assert b''.join(app_iter) == b'' # ensure data will not be sent
+ assert status == "204 NO CONTENT"
+ assert "Content-Length" not in headers
+ assert b"".join(app_iter) == b"" # ensure data will not be sent
def test_modified_url_encoding():
class ModifiedRequest(wrappers.Request):
- url_charset = 'euc-kr'
+ url_charset = "euc-kr"
- req = ModifiedRequest.from_values(u'/?foo=정상처리'.encode('euc-kr'))
- strict_eq(req.args['foo'], u'정상처리')
+ req = ModifiedRequest.from_values(u"/?foo=정상처리".encode("euc-kr"))
+ strict_eq(req.args["foo"], u"정상처리")
def test_request_method_case_sensitivity():
- req = wrappers.Request({'REQUEST_METHOD': 'get'})
- assert req.method == 'GET'
+ req = wrappers.Request({"REQUEST_METHOD": "get"})
+ assert req.method == "GET"
def test_is_xhr_warning():
@@ -1225,7 +1383,7 @@ def test_is_xhr_warning():
def test_write_length():
response = wrappers.Response()
- length = response.stream.write(b'bar')
+ length = response.stream.write(b"bar")
assert length == 3
@@ -1233,13 +1391,13 @@ def test_stream_zip():
import zipfile
response = wrappers.Response()
- with contextlib.closing(zipfile.ZipFile(response.stream, mode='w')) as z:
+ with contextlib.closing(zipfile.ZipFile(response.stream, mode="w")) as z:
z.writestr("foo", b"bar")
buffer = BytesIO(response.get_data())
- with contextlib.closing(zipfile.ZipFile(buffer, mode='r')) as z:
- assert z.namelist() == ['foo']
- assert z.read('foo') == b'bar'
+ with contextlib.closing(zipfile.ZipFile(buffer, mode="r")) as z:
+ assert z.namelist() == ["foo"]
+ assert z.read("foo") == b"bar"
class TestSetCookie(object):
@@ -1247,49 +1405,103 @@ class TestSetCookie(object):
def test_secure(self):
response = wrappers.BaseResponse()
- response.set_cookie('foo', value='bar', max_age=60, expires=0,
- path='/blub', domain='example.org', secure=True,
- samesite=None)
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, '
- '01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; Path=/blub')
- ])
+ response.set_cookie(
+ "foo",
+ value="bar",
+ max_age=60,
+ expires=0,
+ path="/blub",
+ domain="example.org",
+ secure=True,
+ samesite=None,
+ )
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=bar; Domain=example.org; Expires=Thu, "
+ "01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; Path=/blub",
+ ),
+ ],
+ )
def test_httponly(self):
response = wrappers.BaseResponse()
- response.set_cookie('foo', value='bar', max_age=60, expires=0,
- path='/blub', domain='example.org', secure=False,
- httponly=True, samesite=None)
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, '
- '01-Jan-1970 00:00:00 GMT; Max-Age=60; HttpOnly; Path=/blub')
- ])
+ response.set_cookie(
+ "foo",
+ value="bar",
+ max_age=60,
+ expires=0,
+ path="/blub",
+ domain="example.org",
+ secure=False,
+ httponly=True,
+ samesite=None,
+ )
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=bar; Domain=example.org; Expires=Thu, "
+ "01-Jan-1970 00:00:00 GMT; Max-Age=60; HttpOnly; Path=/blub",
+ ),
+ ],
+ )
def test_secure_and_httponly(self):
response = wrappers.BaseResponse()
- response.set_cookie('foo', value='bar', max_age=60, expires=0,
- path='/blub', domain='example.org', secure=True,
- httponly=True, samesite=None)
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, '
- '01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; HttpOnly; '
- 'Path=/blub')
- ])
+ response.set_cookie(
+ "foo",
+ value="bar",
+ max_age=60,
+ expires=0,
+ path="/blub",
+ domain="example.org",
+ secure=True,
+ httponly=True,
+ samesite=None,
+ )
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=bar; Domain=example.org; Expires=Thu, "
+ "01-Jan-1970 00:00:00 GMT; Max-Age=60; Secure; HttpOnly; "
+ "Path=/blub",
+ ),
+ ],
+ )
def test_samesite(self):
response = wrappers.BaseResponse()
- response.set_cookie('foo', value='bar', max_age=60, expires=0,
- path='/blub', domain='example.org', secure=False,
- samesite='strict')
- strict_eq(response.headers.to_wsgi_list(), [
- ('Content-Type', 'text/plain; charset=utf-8'),
- ('Set-Cookie', 'foo=bar; Domain=example.org; Expires=Thu, '
- '01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; '
- 'SameSite=Strict')
- ])
+ response.set_cookie(
+ "foo",
+ value="bar",
+ max_age=60,
+ expires=0,
+ path="/blub",
+ domain="example.org",
+ secure=False,
+ samesite="strict",
+ )
+ strict_eq(
+ response.headers.to_wsgi_list(),
+ [
+ ("Content-Type", "text/plain; charset=utf-8"),
+ (
+ "Set-Cookie",
+ "foo=bar; Domain=example.org; Expires=Thu, "
+ "01-Jan-1970 00:00:00 GMT; Max-Age=60; Path=/blub; "
+ "SameSite=Strict",
+ ),
+ ],
+ )
class TestJSONMixin(object):
@@ -1308,8 +1520,7 @@ class TestJSONMixin(object):
def test_response(self):
value = {u"ä": "b"}
response = self.Response(
- response=json.dumps(value),
- content_type="application/json",
+ response=json.dumps(value), content_type="application/json"
)
assert response.json == value
@@ -1321,8 +1532,7 @@ class TestJSONMixin(object):
def test_silent(self):
request = self.Request.from_values(
- data=b'{"a":}',
- content_type="application/json",
+ data=b'{"a":}', content_type="application/json"
)
assert request.get_json(silent=True) is None
diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py
index 30c212a7..f99aa32c 100644
--- a/tests/test_wsgi.py
+++ b/tests/test_wsgi.py
@@ -14,208 +14,212 @@ import os
import pytest
-from tests import strict_eq
+from . import strict_eq
from werkzeug import wsgi
-from werkzeug._compat import BytesIO, NativeStringIO, StringIO
-from werkzeug.exceptions import BadRequest, ClientDisconnected
-from werkzeug.test import Client, create_environ, run_wsgi_app
+from werkzeug._compat import BytesIO
+from werkzeug._compat import NativeStringIO
+from werkzeug._compat import StringIO
+from werkzeug.exceptions import BadRequest
+from werkzeug.exceptions import ClientDisconnected
+from werkzeug.test import Client
+from werkzeug.test import create_environ
+from werkzeug.test import run_wsgi_app
from werkzeug.wrappers import BaseResponse
-from werkzeug.wsgi import _RangeWrapper, ClosingIterator, wrap_file
-
-
-@pytest.mark.parametrize(('environ', 'expect'), (
- pytest.param({
- 'HTTP_HOST': 'spam',
- }, 'spam', id='host'),
- pytest.param({
- 'HTTP_HOST': 'spam:80',
- }, 'spam', id='host, strip http port'),
- pytest.param({
- 'wsgi.url_scheme': 'https',
- 'HTTP_HOST': 'spam:443',
- }, 'spam', id='host, strip https port'),
- pytest.param({
- 'HTTP_HOST': 'spam:8080',
- }, 'spam:8080', id='host, custom port'),
- pytest.param({
- 'HTTP_HOST': 'spam',
- 'SERVER_NAME': 'eggs',
- 'SERVER_PORT': '80',
- }, 'spam', id='prefer host'),
- pytest.param({
- 'SERVER_NAME': 'eggs',
- 'SERVER_PORT': '80'
- }, 'eggs', id='name, ignore http port'),
- pytest.param({
- 'wsgi.url_scheme': 'https',
- 'SERVER_NAME': 'eggs',
- 'SERVER_PORT': '443'
- }, 'eggs', id='name, ignore https port'),
- pytest.param({
- 'SERVER_NAME': 'eggs',
- 'SERVER_PORT': '8080'
- }, 'eggs:8080', id='name, custom port'),
- pytest.param({
- 'HTTP_HOST': 'ham',
- 'HTTP_X_FORWARDED_HOST': 'eggs'
- }, 'ham', id='ignore x-forwarded-host'),
-))
+from werkzeug.wsgi import _RangeWrapper
+from werkzeug.wsgi import ClosingIterator
+from werkzeug.wsgi import wrap_file
+
+
+@pytest.mark.parametrize(
+ ("environ", "expect"),
+ (
+ pytest.param({"HTTP_HOST": "spam"}, "spam", id="host"),
+ pytest.param({"HTTP_HOST": "spam:80"}, "spam", id="host, strip http port"),
+ pytest.param(
+ {"wsgi.url_scheme": "https", "HTTP_HOST": "spam:443"},
+ "spam",
+ id="host, strip https port",
+ ),
+ pytest.param({"HTTP_HOST": "spam:8080"}, "spam:8080", id="host, custom port"),
+ pytest.param(
+ {"HTTP_HOST": "spam", "SERVER_NAME": "eggs", "SERVER_PORT": "80"},
+ "spam",
+ id="prefer host",
+ ),
+ pytest.param(
+ {"SERVER_NAME": "eggs", "SERVER_PORT": "80"},
+ "eggs",
+ id="name, ignore http port",
+ ),
+ pytest.param(
+ {"wsgi.url_scheme": "https", "SERVER_NAME": "eggs", "SERVER_PORT": "443"},
+ "eggs",
+ id="name, ignore https port",
+ ),
+ pytest.param(
+ {"SERVER_NAME": "eggs", "SERVER_PORT": "8080"},
+ "eggs:8080",
+ id="name, custom port",
+ ),
+ pytest.param(
+ {"HTTP_HOST": "ham", "HTTP_X_FORWARDED_HOST": "eggs"},
+ "ham",
+ id="ignore x-forwarded-host",
+ ),
+ ),
+)
def test_get_host(environ, expect):
- environ.setdefault('wsgi.url_scheme', 'http')
+ environ.setdefault("wsgi.url_scheme", "http")
assert wsgi.get_host(environ) == expect
def test_get_host_validate_trusted_hosts():
- env = {'SERVER_NAME': 'example.org', 'SERVER_PORT': '80',
- 'wsgi.url_scheme': 'http'}
- assert wsgi.get_host(env, trusted_hosts=['.example.org']) == 'example.org'
- pytest.raises(BadRequest, wsgi.get_host, env,
- trusted_hosts=['example.com'])
- env['SERVER_PORT'] = '8080'
- assert wsgi.get_host(env, trusted_hosts=['.example.org:8080']) == 'example.org:8080'
- pytest.raises(BadRequest, wsgi.get_host, env,
- trusted_hosts=['.example.com'])
- env = {'HTTP_HOST': 'example.org', 'wsgi.url_scheme': 'http'}
- assert wsgi.get_host(env, trusted_hosts=['.example.org']) == 'example.org'
- pytest.raises(BadRequest, wsgi.get_host, env,
- trusted_hosts=['example.com'])
+ env = {"SERVER_NAME": "example.org", "SERVER_PORT": "80", "wsgi.url_scheme": "http"}
+ assert wsgi.get_host(env, trusted_hosts=[".example.org"]) == "example.org"
+ pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=["example.com"])
+ env["SERVER_PORT"] = "8080"
+ assert wsgi.get_host(env, trusted_hosts=[".example.org:8080"]) == "example.org:8080"
+ pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=[".example.com"])
+ env = {"HTTP_HOST": "example.org", "wsgi.url_scheme": "http"}
+ assert wsgi.get_host(env, trusted_hosts=[".example.org"]) == "example.org"
+ pytest.raises(BadRequest, wsgi.get_host, env, trusted_hosts=["example.com"])
def test_responder():
def foo(environ, start_response):
- return BaseResponse(b'Test')
+ return BaseResponse(b"Test")
+
client = Client(wsgi.responder(foo), BaseResponse)
- response = client.get('/')
+ response = client.get("/")
assert response.status_code == 200
- assert response.data == b'Test'
+ assert response.data == b"Test"
def test_pop_path_info():
- original_env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b///c'}
+ original_env = {"SCRIPT_NAME": "/foo", "PATH_INFO": "/a/b///c"}
# regular path info popping
def assert_tuple(script_name, path_info):
- assert env.get('SCRIPT_NAME') == script_name
- assert env.get('PATH_INFO') == path_info
+ assert env.get("SCRIPT_NAME") == script_name
+ assert env.get("PATH_INFO") == path_info
+
env = original_env.copy()
- pop = lambda: wsgi.pop_path_info(env)
-
- assert_tuple('/foo', '/a/b///c')
- assert pop() == 'a'
- assert_tuple('/foo/a', '/b///c')
- assert pop() == 'b'
- assert_tuple('/foo/a/b', '///c')
- assert pop() == 'c'
- assert_tuple('/foo/a/b///c', '')
+
+ def pop():
+ return wsgi.pop_path_info(env)
+
+ assert_tuple("/foo", "/a/b///c")
+ assert pop() == "a"
+ assert_tuple("/foo/a", "/b///c")
+ assert pop() == "b"
+ assert_tuple("/foo/a/b", "///c")
+ assert pop() == "c"
+ assert_tuple("/foo/a/b///c", "")
assert pop() is None
def test_peek_path_info():
- env = {
- 'SCRIPT_NAME': '/foo',
- 'PATH_INFO': '/aaa/b///c'
- }
+ env = {"SCRIPT_NAME": "/foo", "PATH_INFO": "/aaa/b///c"}
- assert wsgi.peek_path_info(env) == 'aaa'
- assert wsgi.peek_path_info(env) == 'aaa'
- assert wsgi.peek_path_info(env, charset=None) == b'aaa'
- assert wsgi.peek_path_info(env, charset=None) == b'aaa'
+ assert wsgi.peek_path_info(env) == "aaa"
+ assert wsgi.peek_path_info(env) == "aaa"
+ assert wsgi.peek_path_info(env, charset=None) == b"aaa"
+ assert wsgi.peek_path_info(env, charset=None) == b"aaa"
def test_path_info_and_script_name_fetching():
- env = create_environ(u'/\N{SNOWMAN}', u'http://example.com/\N{COMET}/')
- assert wsgi.get_path_info(env) == u'/\N{SNOWMAN}'
- assert wsgi.get_path_info(env, charset=None) == u'/\N{SNOWMAN}'.encode('utf-8')
- assert wsgi.get_script_name(env) == u'/\N{COMET}'
- assert wsgi.get_script_name(env, charset=None) == u'/\N{COMET}'.encode('utf-8')
+ env = create_environ(u"/\N{SNOWMAN}", u"http://example.com/\N{COMET}/")
+ assert wsgi.get_path_info(env) == u"/\N{SNOWMAN}"
+ assert wsgi.get_path_info(env, charset=None) == u"/\N{SNOWMAN}".encode("utf-8")
+ assert wsgi.get_script_name(env) == u"/\N{COMET}"
+ assert wsgi.get_script_name(env, charset=None) == u"/\N{COMET}".encode("utf-8")
def test_query_string_fetching():
- env = create_environ(u'/?\N{SNOWMAN}=\N{COMET}')
+ env = create_environ(u"/?\N{SNOWMAN}=\N{COMET}")
qs = wsgi.get_query_string(env)
- strict_eq(qs, '%E2%98%83=%E2%98%84')
+ strict_eq(qs, "%E2%98%83=%E2%98%84")
def test_limited_stream():
class RaisingLimitedStream(wsgi.LimitedStream):
-
def on_exhausted(self):
- raise BadRequest('input stream exhausted')
+ raise BadRequest("input stream exhausted")
- io = BytesIO(b'123456')
+ io = BytesIO(b"123456")
stream = RaisingLimitedStream(io, 3)
- strict_eq(stream.read(), b'123')
+ strict_eq(stream.read(), b"123")
pytest.raises(BadRequest, stream.read)
- io = BytesIO(b'123456')
+ io = BytesIO(b"123456")
stream = RaisingLimitedStream(io, 3)
strict_eq(stream.tell(), 0)
- strict_eq(stream.read(1), b'1')
+ strict_eq(stream.read(1), b"1")
strict_eq(stream.tell(), 1)
- strict_eq(stream.read(1), b'2')
+ strict_eq(stream.read(1), b"2")
strict_eq(stream.tell(), 2)
- strict_eq(stream.read(1), b'3')
+ strict_eq(stream.read(1), b"3")
strict_eq(stream.tell(), 3)
pytest.raises(BadRequest, stream.read)
- io = BytesIO(b'123456\nabcdefg')
+ io = BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io, 9)
- strict_eq(stream.readline(), b'123456\n')
- strict_eq(stream.readline(), b'ab')
+ strict_eq(stream.readline(), b"123456\n")
+ strict_eq(stream.readline(), b"ab")
- io = BytesIO(b'123456\nabcdefg')
+ io = BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io, 9)
- strict_eq(stream.readlines(), [b'123456\n', b'ab'])
+ strict_eq(stream.readlines(), [b"123456\n", b"ab"])
- io = BytesIO(b'123456\nabcdefg')
+ io = BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io, 9)
- strict_eq(stream.readlines(2), [b'12'])
- strict_eq(stream.readlines(2), [b'34'])
- strict_eq(stream.readlines(), [b'56\n', b'ab'])
+ strict_eq(stream.readlines(2), [b"12"])
+ strict_eq(stream.readlines(2), [b"34"])
+ strict_eq(stream.readlines(), [b"56\n", b"ab"])
- io = BytesIO(b'123456\nabcdefg')
+ io = BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io, 9)
- strict_eq(stream.readline(100), b'123456\n')
+ strict_eq(stream.readline(100), b"123456\n")
- io = BytesIO(b'123456\nabcdefg')
+ io = BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io, 9)
- strict_eq(stream.readlines(100), [b'123456\n', b'ab'])
+ strict_eq(stream.readlines(100), [b"123456\n", b"ab"])
- io = BytesIO(b'123456')
+ io = BytesIO(b"123456")
stream = wsgi.LimitedStream(io, 3)
- strict_eq(stream.read(1), b'1')
- strict_eq(stream.read(1), b'2')
- strict_eq(stream.read(), b'3')
- strict_eq(stream.read(), b'')
+ strict_eq(stream.read(1), b"1")
+ strict_eq(stream.read(1), b"2")
+ strict_eq(stream.read(), b"3")
+ strict_eq(stream.read(), b"")
- io = BytesIO(b'123456')
+ io = BytesIO(b"123456")
stream = wsgi.LimitedStream(io, 3)
- strict_eq(stream.read(-1), b'123')
+ strict_eq(stream.read(-1), b"123")
- io = BytesIO(b'123456')
+ io = BytesIO(b"123456")
stream = wsgi.LimitedStream(io, 0)
- strict_eq(stream.read(-1), b'')
+ strict_eq(stream.read(-1), b"")
- io = StringIO(u'123456')
+ io = StringIO(u"123456")
stream = wsgi.LimitedStream(io, 0)
- strict_eq(stream.read(-1), u'')
+ strict_eq(stream.read(-1), u"")
- io = StringIO(u'123\n456\n')
+ io = StringIO(u"123\n456\n")
stream = wsgi.LimitedStream(io, 8)
- strict_eq(list(stream), [u'123\n', u'456\n'])
+ strict_eq(list(stream), [u"123\n", u"456\n"])
def test_limited_stream_json_load():
stream = wsgi.LimitedStream(BytesIO(b'{"hello": "test"}'), 17)
# flask.json adapts bytes to text with TextIOWrapper
# this expects stream.readable() to exist and return true
- stream = io.TextIOWrapper(io.BufferedReader(stream), 'UTF-8')
+ stream = io.TextIOWrapper(io.BufferedReader(stream), "UTF-8")
data = json.load(stream)
- assert data == {'hello': 'test'}
+ assert data == {"hello": "test"}
def test_limited_stream_disconnection():
- io = BytesIO(b'A bit of content')
+ io = BytesIO(b"A bit of content")
# disconnect detection on out of bytes
stream = wsgi.LimitedStream(io, 255)
@@ -223,7 +227,7 @@ def test_limited_stream_disconnection():
stream.read()
# disconnect detection because file close
- io = BytesIO(b'x' * 255)
+ io = BytesIO(b"x" * 255)
io.close()
stream = wsgi.LimitedStream(io, 255)
with pytest.raises(ClientDisconnected):
@@ -231,46 +235,58 @@ def test_limited_stream_disconnection():
def test_path_info_extraction():
- x = wsgi.extract_path_info('http://example.com/app', '/app/hello')
- assert x == u'/hello'
- x = wsgi.extract_path_info('http://example.com/app',
- 'https://example.com/app/hello')
- assert x == u'/hello'
- x = wsgi.extract_path_info('http://example.com/app/',
- 'https://example.com/app/hello')
- assert x == u'/hello'
- x = wsgi.extract_path_info('http://example.com/app/',
- 'https://example.com/app')
- assert x == u'/'
- x = wsgi.extract_path_info(u'http://☃.net/', u'/fööbär')
- assert x == u'/fööbär'
- x = wsgi.extract_path_info(u'http://☃.net/x', u'http://☃.net/x/fööbär')
- assert x == u'/fööbär'
-
- env = create_environ(u'/fööbär', u'http://☃.net/x/')
- x = wsgi.extract_path_info(env, u'http://☃.net/x/fööbär')
- assert x == u'/fööbär'
-
- x = wsgi.extract_path_info('http://example.com/app/',
- 'https://example.com/a/hello')
+ x = wsgi.extract_path_info("http://example.com/app", "/app/hello")
+ assert x == u"/hello"
+ x = wsgi.extract_path_info(
+ "http://example.com/app", "https://example.com/app/hello"
+ )
+ assert x == u"/hello"
+ x = wsgi.extract_path_info(
+ "http://example.com/app/", "https://example.com/app/hello"
+ )
+ assert x == u"/hello"
+ x = wsgi.extract_path_info("http://example.com/app/", "https://example.com/app")
+ assert x == u"/"
+ x = wsgi.extract_path_info(u"http://☃.net/", u"/fööbär")
+ assert x == u"/fööbär"
+ x = wsgi.extract_path_info(u"http://☃.net/x", u"http://☃.net/x/fööbär")
+ assert x == u"/fööbär"
+
+ env = create_environ(u"/fööbär", u"http://☃.net/x/")
+ x = wsgi.extract_path_info(env, u"http://☃.net/x/fööbär")
+ assert x == u"/fööbär"
+
+ x = wsgi.extract_path_info("http://example.com/app/", "https://example.com/a/hello")
assert x is None
- x = wsgi.extract_path_info('http://example.com/app/',
- 'https://example.com/app/hello',
- collapse_http_schemes=False)
+ x = wsgi.extract_path_info(
+ "http://example.com/app/",
+ "https://example.com/app/hello",
+ collapse_http_schemes=False,
+ )
assert x is None
def test_get_host_fallback():
- assert wsgi.get_host({
- 'SERVER_NAME': 'foobar.example.com',
- 'wsgi.url_scheme': 'http',
- 'SERVER_PORT': '80'
- }) == 'foobar.example.com'
- assert wsgi.get_host({
- 'SERVER_NAME': 'foobar.example.com',
- 'wsgi.url_scheme': 'http',
- 'SERVER_PORT': '81'
- }) == 'foobar.example.com:81'
+ assert (
+ wsgi.get_host(
+ {
+ "SERVER_NAME": "foobar.example.com",
+ "wsgi.url_scheme": "http",
+ "SERVER_PORT": "80",
+ }
+ )
+ == "foobar.example.com"
+ )
+ assert (
+ wsgi.get_host(
+ {
+ "SERVER_NAME": "foobar.example.com",
+ "wsgi.url_scheme": "http",
+ "SERVER_PORT": "81",
+ }
+ )
+ == "foobar.example.com:81"
+ )
def test_get_current_url_unicode():
@@ -289,150 +305,172 @@ def test_get_current_url_invalid_utf8():
def test_multi_part_line_breaks():
- data = 'abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK'
+ data = "abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK"
test_stream = NativeStringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data),
- buffer_size=16))
- assert lines == ['abcdef\r\n', 'ghijkl\r\n', 'mnopqrstuvwxyz\r\n',
- 'ABCDEFGHIJK']
+ lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16))
+ assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"]
- data = 'abc\r\nThis line is broken by the buffer length.' \
- '\r\nFoo bar baz'
+ data = "abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz"
test_stream = NativeStringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data),
- buffer_size=24))
- assert lines == ['abc\r\n', 'This line is broken by the buffer '
- 'length.\r\n', 'Foo bar baz']
+ lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24))
+ assert lines == [
+ "abc\r\n",
+ "This line is broken by the buffer length.\r\n",
+ "Foo bar baz",
+ ]
def test_multi_part_line_breaks_bytes():
- data = b'abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK'
+ data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK"
test_stream = BytesIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data),
- buffer_size=16))
- assert lines == [b'abcdef\r\n', b'ghijkl\r\n', b'mnopqrstuvwxyz\r\n',
- b'ABCDEFGHIJK']
-
- data = b'abc\r\nThis line is broken by the buffer length.' \
- b'\r\nFoo bar baz'
+ lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16))
+ assert lines == [
+ b"abcdef\r\n",
+ b"ghijkl\r\n",
+ b"mnopqrstuvwxyz\r\n",
+ b"ABCDEFGHIJK",
+ ]
+
+ data = b"abc\r\nThis line is broken by the buffer length." b"\r\nFoo bar baz"
test_stream = BytesIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data),
- buffer_size=24))
- assert lines == [b'abc\r\n', b'This line is broken by the buffer '
- b'length.\r\n', b'Foo bar baz']
+ lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24))
+ assert lines == [
+ b"abc\r\n",
+ b"This line is broken by the buffer " b"length.\r\n",
+ b"Foo bar baz",
+ ]
def test_multi_part_line_breaks_problematic():
- data = 'abc\rdef\r\nghi'
- for x in range(1, 10):
+ data = "abc\rdef\r\nghi"
+ for _ in range(1, 10):
test_stream = NativeStringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data),
- buffer_size=4))
- assert lines == ['abc\r', 'def\r\n', 'ghi']
+ lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4))
+ assert lines == ["abc\r", "def\r\n", "ghi"]
def test_iter_functions_support_iterators():
- data = ['abcdef\r\nghi', 'jkl\r\nmnopqrstuvwxyz\r', '\nABCDEFGHIJK']
+ data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"]
lines = list(wsgi.make_line_iter(data))
- assert lines == ['abcdef\r\n', 'ghijkl\r\n', 'mnopqrstuvwxyz\r\n',
- 'ABCDEFGHIJK']
+ assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"]
def test_make_chunk_iter():
- data = [u'abcdefXghi', u'jklXmnopqrstuvwxyzX', u'ABCDEFGHIJK']
- rv = list(wsgi.make_chunk_iter(data, 'X'))
- assert rv == [u'abcdef', u'ghijkl', u'mnopqrstuvwxyz', u'ABCDEFGHIJK']
+ data = [u"abcdefXghi", u"jklXmnopqrstuvwxyzX", u"ABCDEFGHIJK"]
+ rv = list(wsgi.make_chunk_iter(data, "X"))
+ assert rv == [u"abcdef", u"ghijkl", u"mnopqrstuvwxyz", u"ABCDEFGHIJK"]
- data = u'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK'
+ data = u"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
test_stream = StringIO(data)
- rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data),
- buffer_size=4))
- assert rv == [u'abcdef', u'ghijkl', u'mnopqrstuvwxyz', u'ABCDEFGHIJK']
+ rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4))
+ assert rv == [u"abcdef", u"ghijkl", u"mnopqrstuvwxyz", u"ABCDEFGHIJK"]
def test_make_chunk_iter_bytes():
- data = [b'abcdefXghi', b'jklXmnopqrstuvwxyzX', b'ABCDEFGHIJK']
- rv = list(wsgi.make_chunk_iter(data, 'X'))
- assert rv == [b'abcdef', b'ghijkl', b'mnopqrstuvwxyz', b'ABCDEFGHIJK']
+ data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"]
+ rv = list(wsgi.make_chunk_iter(data, "X"))
+ assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"]
- data = b'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK'
+ data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
test_stream = BytesIO(data)
- rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data),
- buffer_size=4))
- assert rv == [b'abcdef', b'ghijkl', b'mnopqrstuvwxyz', b'ABCDEFGHIJK']
+ rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4))
+ assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"]
- data = b'abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK'
+ data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
test_stream = BytesIO(data)
- rv = list(wsgi.make_chunk_iter(test_stream, 'X', limit=len(data),
- buffer_size=4, cap_at_buffer=True))
- assert rv == [b'abcd', b'ef', b'ghij', b'kl', b'mnop', b'qrst', b'uvwx',
- b'yz', b'ABCD', b'EFGH', b'IJK']
+ rv = list(
+ wsgi.make_chunk_iter(
+ test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True
+ )
+ )
+ assert rv == [
+ b"abcd",
+ b"ef",
+ b"ghij",
+ b"kl",
+ b"mnop",
+ b"qrst",
+ b"uvwx",
+ b"yz",
+ b"ABCD",
+ b"EFGH",
+ b"IJK",
+ ]
def test_lines_longer_buffer_size():
- data = '1234567890\n1234567890\n'
+ data = "1234567890\n1234567890\n"
for bufsize in range(1, 15):
- lines = list(wsgi.make_line_iter(NativeStringIO(data), limit=len(data),
- buffer_size=4))
- assert lines == ['1234567890\n', '1234567890\n']
+ lines = list(
+ wsgi.make_line_iter(
+ NativeStringIO(data), limit=len(data), buffer_size=bufsize
+ )
+ )
+ assert lines == ["1234567890\n", "1234567890\n"]
def test_lines_longer_buffer_size_cap():
- data = '1234567890\n1234567890\n'
+ data = "1234567890\n1234567890\n"
for bufsize in range(1, 15):
- lines = list(wsgi.make_line_iter(NativeStringIO(data), limit=len(data),
- buffer_size=4, cap_at_buffer=True))
- assert lines == ['1234', '5678', '90\n', '1234', '5678', '90\n']
+ lines = list(
+ wsgi.make_line_iter(
+ NativeStringIO(data),
+ limit=len(data),
+ buffer_size=bufsize,
+ cap_at_buffer=True,
+ )
+ )
+ assert len(lines[0]) == bufsize or lines[0].endswith("\n")
def test_range_wrapper():
- response = BaseResponse(b'Hello World')
+ response = BaseResponse(b"Hello World")
range_wrapper = _RangeWrapper(response.response, 6, 4)
- assert next(range_wrapper) == b'Worl'
+ assert next(range_wrapper) == b"Worl"
- response = BaseResponse(b'Hello World')
+ response = BaseResponse(b"Hello World")
range_wrapper = _RangeWrapper(response.response, 1, 0)
with pytest.raises(StopIteration):
next(range_wrapper)
- response = BaseResponse(b'Hello World')
+ response = BaseResponse(b"Hello World")
range_wrapper = _RangeWrapper(response.response, 6, 100)
- assert next(range_wrapper) == b'World'
+ assert next(range_wrapper) == b"World"
- response = BaseResponse((x for x in (b'He', b'll', b'o ', b'Wo', b'rl', b'd')))
+ response = BaseResponse((x for x in (b"He", b"ll", b"o ", b"Wo", b"rl", b"d")))
range_wrapper = _RangeWrapper(response.response, 6, 4)
assert not range_wrapper.seekable
- assert next(range_wrapper) == b'Wo'
- assert next(range_wrapper) == b'rl'
+ assert next(range_wrapper) == b"Wo"
+ assert next(range_wrapper) == b"rl"
- response = BaseResponse((x for x in (b'He', b'll', b'o W', b'o', b'rld')))
+ response = BaseResponse((x for x in (b"He", b"ll", b"o W", b"o", b"rld")))
range_wrapper = _RangeWrapper(response.response, 6, 4)
- assert next(range_wrapper) == b'W'
- assert next(range_wrapper) == b'o'
- assert next(range_wrapper) == b'rl'
+ assert next(range_wrapper) == b"W"
+ assert next(range_wrapper) == b"o"
+ assert next(range_wrapper) == b"rl"
with pytest.raises(StopIteration):
next(range_wrapper)
- response = BaseResponse((x for x in (b'Hello', b' World')))
+ response = BaseResponse((x for x in (b"Hello", b" World")))
range_wrapper = _RangeWrapper(response.response, 1, 1)
- assert next(range_wrapper) == b'e'
+ assert next(range_wrapper) == b"e"
with pytest.raises(StopIteration):
next(range_wrapper)
- resources = os.path.join(os.path.dirname(__file__), 'res')
+ resources = os.path.join(os.path.dirname(__file__), "res")
env = create_environ()
- with open(os.path.join(resources, 'test.txt'), 'rb') as f:
+ with open(os.path.join(resources, "test.txt"), "rb") as f:
response = BaseResponse(wrap_file(env, f))
range_wrapper = _RangeWrapper(response.response, 1, 2)
assert range_wrapper.seekable
- assert next(range_wrapper) == b'OU'
+ assert next(range_wrapper) == b"OU"
with pytest.raises(StopIteration):
next(range_wrapper)
- with open(os.path.join(resources, 'test.txt'), 'rb') as f:
+ with open(os.path.join(resources, "test.txt"), "rb") as f:
response = BaseResponse(wrap_file(env, f))
range_wrapper = _RangeWrapper(response.response, 2)
- assert next(range_wrapper) == b'UND\n'
+ assert next(range_wrapper) == b"UND\n"
with pytest.raises(StopIteration):
next(range_wrapper)
@@ -450,8 +488,8 @@ def test_closing_iterator():
# iterator. This ensures that ClosingIterator calls close on
# the iterable (the object), not the iterator.
def __iter__(self):
- self.start('200 OK', [('Content-Type', 'text/plain')])
- yield 'some content'
+ self.start("200 OK", [("Content-Type", "text/plain")])
+ yield "some content"
def close(self):
Namespace.got_close = True
@@ -462,9 +500,8 @@ def test_closing_iterator():
def app(environ, start_response):
return ClosingIterator(Response(environ, start_response), additional)
- app_iter, status, headers = run_wsgi_app(
- app, create_environ(), buffered=True)
+ app_iter, status, headers = run_wsgi_app(app, create_environ(), buffered=True)
- assert ''.join(app_iter) == 'some content'
+ assert "".join(app_iter) == "some content"
assert Namespace.got_close
assert Namespace.got_additional
diff --git a/tox.ini b/tox.ini
index 05bd20ee..902882cd 100644
--- a/tox.ini
+++ b/tox.ini
@@ -19,9 +19,9 @@ deps =
commands = coverage run -p -m pytest --tb=short --basetemp={envtmpdir} {posargs}
[testenv:stylecheck]
-deps = flake8
+deps = pre-commit
skip_install = true
-commands = flake8 {posargs}
+commands = pre-commit run --all-files --show-diff-on-failure
[testenv:docs-html]
deps = pallets-sphinx-themes
diff --git a/werkzeug-import-rewrite.py b/werkzeug-import-rewrite.py
index 0d909d36..81af6a55 100644
--- a/werkzeug-import-rewrite.py
+++ b/werkzeug-import-rewrite.py
@@ -10,118 +10,194 @@
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
-import sys
+import difflib
import os
-import re
import posixpath
-import difflib
+import re
+import sys
-_from_import_re = re.compile(r'(\s*(>>>|\.\.\.)?\s*)from werkzeug import\s+')
-_direct_usage = re.compile(r'(?<!`)(werkzeug\.)([a-zA-Z_][a-zA-Z0-9_]+)')
+_from_import_re = re.compile(r"(\s*(>>>|\.\.\.)?\s*)from werkzeug import\s+")
+_direct_usage = re.compile(r"(?<!`)(werkzeug\.)([a-zA-Z_][a-zA-Z0-9_]+)")
# not necessarily in sync with current werkzeug/__init__.py
all_by_module = {
- 'werkzeug.debug': ['DebuggedApplication'],
- 'werkzeug.local': ['Local', 'LocalManager', 'LocalProxy', 'LocalStack',
- 'release_local'],
- 'werkzeug.serving': ['run_simple'],
- 'werkzeug.test': ['Client', 'EnvironBuilder', 'create_environ',
- 'run_wsgi_app'],
- 'werkzeug.testapp': ['test_app'],
- 'werkzeug.exceptions': ['abort', 'Aborter'],
- 'werkzeug.urls': ['url_decode', 'url_encode', 'url_quote',
- 'url_quote_plus', 'url_unquote', 'url_unquote_plus',
- 'url_fix', 'Href', 'iri_to_uri', 'uri_to_iri'],
- 'werkzeug.formparser': ['parse_form_data'],
- 'werkzeug.utils': ['escape', 'environ_property', 'append_slash_redirect',
- 'redirect', 'cached_property', 'import_string',
- 'dump_cookie', 'parse_cookie', 'unescape',
- 'format_string', 'find_modules', 'header_property',
- 'html', 'xhtml', 'HTMLBuilder', 'validate_arguments',
- 'ArgumentValidationError', 'bind_arguments',
- 'secure_filename'],
- 'werkzeug.wsgi': ['get_current_url', 'get_host', 'pop_path_info',
- 'peek_path_info', 'SharedDataMiddleware',
- 'DispatcherMiddleware', 'ClosingIterator', 'FileWrapper',
- 'make_line_iter', 'LimitedStream', 'responder',
- 'wrap_file', 'extract_path_info'],
- 'werkzeug.datastructures': ['MultiDict', 'CombinedMultiDict', 'Headers',
- 'EnvironHeaders', 'ImmutableList',
- 'ImmutableDict', 'ImmutableMultiDict',
- 'TypeConversionDict',
- 'ImmutableTypeConversionDict', 'Accept',
- 'MIMEAccept', 'CharsetAccept',
- 'LanguageAccept', 'RequestCacheControl',
- 'ResponseCacheControl', 'ETags', 'HeaderSet',
- 'WWWAuthenticate', 'Authorization',
- 'FileMultiDict', 'CallbackDict', 'FileStorage',
- 'OrderedMultiDict', 'ImmutableOrderedMultiDict'
- ],
- 'werkzeug.useragents': ['UserAgent'],
- 'werkzeug.http': ['parse_etags', 'parse_date', 'http_date', 'cookie_date',
- 'parse_cache_control_header', 'is_resource_modified',
- 'parse_accept_header', 'parse_set_header', 'quote_etag',
- 'unquote_etag', 'generate_etag', 'dump_header',
- 'parse_list_header', 'parse_dict_header',
- 'parse_authorization_header',
- 'parse_www_authenticate_header', 'remove_entity_headers',
- 'is_entity_header', 'remove_hop_by_hop_headers',
- 'parse_options_header', 'dump_options_header',
- 'is_hop_by_hop_header', 'unquote_header_value',
- 'quote_header_value', 'HTTP_STATUS_CODES'],
- 'werkzeug.wrappers': ['BaseResponse', 'BaseRequest', 'Request', 'Response',
- 'AcceptMixin', 'ETagRequestMixin',
- 'ETagResponseMixin', 'ResponseStreamMixin',
- 'CommonResponseDescriptorsMixin', 'UserAgentMixin',
- 'AuthorizationMixin', 'WWWAuthenticateMixin',
- 'CommonRequestDescriptorsMixin'],
- 'werkzeug.security': ['generate_password_hash', 'check_password_hash'],
+ "werkzeug.debug": ["DebuggedApplication"],
+ "werkzeug.local": [
+ "Local",
+ "LocalManager",
+ "LocalProxy",
+ "LocalStack",
+ "release_local",
+ ],
+ "werkzeug.serving": ["run_simple"],
+ "werkzeug.test": ["Client", "EnvironBuilder", "create_environ", "run_wsgi_app"],
+ "werkzeug.testapp": ["test_app"],
+ "werkzeug.exceptions": ["abort", "Aborter"],
+ "werkzeug.urls": [
+ "url_decode",
+ "url_encode",
+ "url_quote",
+ "url_quote_plus",
+ "url_unquote",
+ "url_unquote_plus",
+ "url_fix",
+ "Href",
+ "iri_to_uri",
+ "uri_to_iri",
+ ],
+ "werkzeug.formparser": ["parse_form_data"],
+ "werkzeug.utils": [
+ "escape",
+ "environ_property",
+ "append_slash_redirect",
+ "redirect",
+ "cached_property",
+ "import_string",
+ "dump_cookie",
+ "parse_cookie",
+ "unescape",
+ "format_string",
+ "find_modules",
+ "header_property",
+ "html",
+ "xhtml",
+ "HTMLBuilder",
+ "validate_arguments",
+ "ArgumentValidationError",
+ "bind_arguments",
+ "secure_filename",
+ ],
+ "werkzeug.wsgi": [
+ "get_current_url",
+ "get_host",
+ "pop_path_info",
+ "peek_path_info",
+ "SharedDataMiddleware",
+ "DispatcherMiddleware",
+ "ClosingIterator",
+ "FileWrapper",
+ "make_line_iter",
+ "LimitedStream",
+ "responder",
+ "wrap_file",
+ "extract_path_info",
+ ],
+ "werkzeug.datastructures": [
+ "MultiDict",
+ "CombinedMultiDict",
+ "Headers",
+ "EnvironHeaders",
+ "ImmutableList",
+ "ImmutableDict",
+ "ImmutableMultiDict",
+ "TypeConversionDict",
+ "ImmutableTypeConversionDict",
+ "Accept",
+ "MIMEAccept",
+ "CharsetAccept",
+ "LanguageAccept",
+ "RequestCacheControl",
+ "ResponseCacheControl",
+ "ETags",
+ "HeaderSet",
+ "WWWAuthenticate",
+ "Authorization",
+ "FileMultiDict",
+ "CallbackDict",
+ "FileStorage",
+ "OrderedMultiDict",
+ "ImmutableOrderedMultiDict",
+ ],
+ "werkzeug.useragents": ["UserAgent"],
+ "werkzeug.http": [
+ "parse_etags",
+ "parse_date",
+ "http_date",
+ "cookie_date",
+ "parse_cache_control_header",
+ "is_resource_modified",
+ "parse_accept_header",
+ "parse_set_header",
+ "quote_etag",
+ "unquote_etag",
+ "generate_etag",
+ "dump_header",
+ "parse_list_header",
+ "parse_dict_header",
+ "parse_authorization_header",
+ "parse_www_authenticate_header",
+ "remove_entity_headers",
+ "is_entity_header",
+ "remove_hop_by_hop_headers",
+ "parse_options_header",
+ "dump_options_header",
+ "is_hop_by_hop_header",
+ "unquote_header_value",
+ "quote_header_value",
+ "HTTP_STATUS_CODES",
+ ],
+ "werkzeug.wrappers": [
+ "BaseResponse",
+ "BaseRequest",
+ "Request",
+ "Response",
+ "AcceptMixin",
+ "ETagRequestMixin",
+ "ETagResponseMixin",
+ "ResponseStreamMixin",
+ "CommonResponseDescriptorsMixin",
+ "UserAgentMixin",
+ "AuthorizationMixin",
+ "WWWAuthenticateMixin",
+ "CommonRequestDescriptorsMixin",
+ ],
+ "werkzeug.security": ["generate_password_hash", "check_password_hash"],
# the undocumented easteregg ;-)
- 'werkzeug._internal': ['_easteregg']
+ "werkzeug._internal": ["_easteregg"],
}
by_item = {}
-for module, names in all_by_module.iteritems():
+for module, names in all_by_module.items():
for name in names:
by_item[name] = module
def find_module(item):
- return by_item.get(item, 'werkzeug')
+ return by_item.get(item, "werkzeug")
def complete_fromlist(fromlist, lineiter):
fromlist = fromlist.strip()
if not fromlist:
return []
- if fromlist[0] == '(':
- if fromlist[-1] == ')':
- return fromlist[1:-1].strip().split(',')
- fromlist = fromlist[1:].strip().split(',')
+ if fromlist[0] == "(":
+ if fromlist[-1] == ")":
+ return fromlist[1:-1].strip().split(",")
+ fromlist = fromlist[1:].strip().split(",")
for line in lineiter:
line = line.strip()
abort = False
- if line.endswith(')'):
+ if line.endswith(")"):
line = line[:-1]
abort = True
- fromlist.extend(line.split(','))
+ fromlist.extend(line.split(","))
if abort:
break
return fromlist
- elif fromlist[-1] == '\\':
- fromlist = fromlist[:-1].strip().split(',')
+ elif fromlist[-1] == "\\":
+ fromlist = fromlist[:-1].strip().split(",")
for line in lineiter:
line = line.strip()
abort = True
- if line.endswith('\\'):
+ if line.endswith("\\"):
abort = False
line = line[:-1]
- fromlist.extend(line.split(','))
+ fromlist.extend(line.split(","))
if abort:
break
return fromlist
- return fromlist.split(',')
+ return fromlist.split(",")
def rewrite_from_imports(fromlist, indentation, lineiter):
@@ -132,43 +208,48 @@ def rewrite_from_imports(fromlist, indentation, lineiter):
continue
if len(item) == 1:
parsed_items.append((item[0], None))
- elif len(item) == 3 and item[1] == 'as':
+ elif len(item) == 3 and item[1] == "as":
parsed_items.append((item[0], item[2]))
else:
- raise ValueError('invalid syntax for import')
+ raise ValueError("invalid syntax for import")
new_imports = {}
for item, alias in parsed_items:
new_imports.setdefault(find_module(item), []).append((item, alias))
for module_name, items in sorted(new_imports.items()):
- fromlist_items = sorted(['%s%s' % (
- item,
- alias is not None and (' as ' + alias) or ''
- ) for (item, alias) in items], reverse=True)
+ fromlist_items = sorted(
+ [
+ "%s%s" % (item, alias is not None and (" as " + alias) or "")
+ for (item, alias) in items
+ ],
+ reverse=True,
+ )
- prefix = '%sfrom %s import ' % (indentation, module_name)
+ prefix = "%sfrom %s import " % (indentation, module_name)
item_buffer = []
while fromlist_items:
item_buffer.append(fromlist_items.pop())
- fromlist = ', '.join(item_buffer)
+ fromlist = ", ".join(item_buffer)
if len(fromlist) + len(prefix) > 79:
- yield prefix + ', '.join(item_buffer[:-1]) + ', \\'
+ yield prefix + ", ".join(item_buffer[:-1]) + ", \\"
item_buffer = [item_buffer[-1]]
# doctest continuations
- indentation = indentation.replace('>', '.')
- prefix = indentation + ' '
- yield prefix + ', '.join(item_buffer)
+ indentation = indentation.replace(">", ".")
+ prefix = indentation + " "
+ yield prefix + ", ".join(item_buffer)
def inject_imports(lines, imports):
pos = 0
for idx, line in enumerate(lines):
- if re.match(r'(from|import)\s+werkzeug', line):
+ if re.match(r"(from|import)\s+werkzeug", line):
pos = idx
break
- lines[pos:pos] = ['from %s import %s' % (mod, ', '.join(sorted(attrs)))
- for mod, attrs in sorted(imports.items())]
+ lines[pos:pos] = [
+ "from %s import %s" % (mod, ", ".join(sorted(attrs)))
+ for mod, attrs in sorted(imports.items())
+ ]
def rewrite_file(filename):
@@ -182,48 +263,48 @@ def rewrite_file(filename):
# rewrite from imports
match = _from_import_re.search(line)
if match is not None:
- fromlist = line[match.end():]
- new_file.extend(rewrite_from_imports(fromlist,
- match.group(1),
- lineiter))
+ fromlist = line[match.end() :]
+ new_file.extend(rewrite_from_imports(fromlist, match.group(1), lineiter))
continue
def _handle_match(match):
# rewrite attribute access to 'werkzeug'
attr = match.group(2)
mod = find_module(attr)
- if mod == 'werkzeug':
+ if mod == "werkzeug":
return match.group(0)
deferred_imports.setdefault(mod, []).append(attr)
return attr
+
new_file.append(_direct_usage.sub(_handle_match, line))
if deferred_imports:
inject_imports(new_file, deferred_imports)
for line in difflib.unified_diff(
- old_file, new_file,
- posixpath.normpath(posixpath.join('a', filename)),
- posixpath.normpath(posixpath.join('b', filename)),
- lineterm=''
+ old_file,
+ new_file,
+ posixpath.normpath(posixpath.join("a", filename)),
+ posixpath.normpath(posixpath.join("b", filename)),
+ lineterm="",
):
print(line)
def rewrite_in_folders(folders):
for folder in folders:
- for dirpath, dirnames, filenames in os.walk(folder):
+ for dirpath, _dirnames, filenames in os.walk(folder):
for filename in filenames:
filename = os.path.join(dirpath, filename)
- if filename.endswith(('.rst', '.py')):
+ if filename.endswith((".rst", ".py")):
rewrite_file(filename)
def main():
if len(sys.argv) == 1:
- print('usage: werkzeug-import-rewrite.py [folders]')
+ print("usage: werkzeug-import-rewrite.py [folders]")
sys.exit(1)
rewrite_in_folders(sys.argv[1:])
-if __name__ == '__main__':
+if __name__ == "__main__":
main()