summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2017-09-24 09:14:35 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2017-09-24 09:14:35 +0200
commitbf9ce73e3720bdbf1cc671f35f4f299511d59650 (patch)
tree258258cc2fae4ec7f80f5e0a4c7fa89f05eb7e4b
parent097478e47fbc0423118f82a0a7b458c2e9dbea7b (diff)
downloadsqlparse-bf9ce73e3720bdbf1cc671f35f4f299511d59650.tar.gz
Close files during tests.
-rwxr-xr-xsqlparse/cli.py7
-rw-r--r--tests/test_cli.py9
-rw-r--r--tests/test_regressions.py12
3 files changed, 17 insertions, 11 deletions
diff --git a/sqlparse/cli.py b/sqlparse/cli.py
index 0b5c204..ad6bc7a 100755
--- a/sqlparse/cli.py
+++ b/sqlparse/cli.py
@@ -154,14 +154,17 @@ def main(args=None):
sys.stdin.buffer, encoding=args.encoding).read()
else:
try:
- data = ''.join(open(args.filename, 'r', args.encoding).readlines())
+ with open(args.filename, 'r', args.encoding) as f:
+ data = ''.join(f.readlines())
except IOError as e:
return _error(
u'Failed to read {0}: {1}'.format(args.filename, e))
+ close_stream = False
if args.outfile:
try:
stream = open(args.outfile, 'w', args.encoding)
+ close_stream = True
except IOError as e:
return _error(u'Failed to open {0}: {1}'.format(args.outfile, e))
else:
@@ -176,4 +179,6 @@ def main(args=None):
s = sqlparse.format(data, **formatter_opts)
stream.write(s)
stream.flush()
+ if close_stream:
+ stream.close()
return 0
diff --git a/tests/test_cli.py b/tests/test_cli.py
index c1a5a75..18c6fcb 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -127,10 +127,11 @@ def test_encoding_stdin_gbk(filepath, load_file, capfd):
path = filepath('encoding_gbk.sql')
expected = load_file('encoding_gbk.sql', 'gbk')
old_stdin = sys.stdin
- sys.stdin = open(path, 'r')
- sys.stdout.encoding = 'gbk'
- sqlparse.cli.main(['-', '--encoding', 'gbk'])
- sys.stdin = old_stdin
+ with open(path, 'r') as stream:
+ sys.stdin = stream
+ sys.stdout.encoding = 'gbk'
+ sqlparse.cli.main(['-', '--encoding', 'gbk'])
+ sys.stdin = old_stdin
out, _ = capfd.readouterr()
assert out == expected
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index 89828f0..406328c 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -186,9 +186,9 @@ def test_format_accepts_encoding(load_file):
def test_stream(get_stream):
- stream = get_stream("stream.sql")
- p = sqlparse.parse(stream)[0]
- assert p.get_type() == 'INSERT'
+ with get_stream("stream.sql") as stream:
+ p = sqlparse.parse(stream)[0]
+ assert p.get_type() == 'INSERT'
def test_issue90():
@@ -238,9 +238,9 @@ def test_null_with_as():
def test_issue190_open_file(filepath):
path = filepath('stream.sql')
- stream = open(path)
- p = sqlparse.parse(stream)[0]
- assert p.get_type() == 'INSERT'
+ with open(path) as stream:
+ p = sqlparse.parse(stream)[0]
+ assert p.get_type() == 'INSERT'
def test_issue193_splitting_function():