summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--coverage/annotate.py2
-rw-r--r--coverage/codeunit.py33
-rw-r--r--coverage/html.py4
-rw-r--r--coverage/parser.py2
-rw-r--r--coverage/results.py6
-rw-r--r--tests/test_codeunit.py28
-rw-r--r--tests/test_parser.py4
7 files changed, 38 insertions, 41 deletions
diff --git a/coverage/annotate.py b/coverage/annotate.py
index dae9f4cf..5b96448a 100644
--- a/coverage/annotate.py
+++ b/coverage/annotate.py
@@ -61,7 +61,7 @@ class AnnotateReporter(Reporter):
i = 0
j = 0
covered = True
- source = cu.source_file().read()
+ source = cu.source()
for lineno, line in enumerate(source.splitlines(True), start=1):
while i < len(statements) and statements[i] < lineno:
i += 1
diff --git a/coverage/codeunit.py b/coverage/codeunit.py
index 23f49e8b..2ee19104 100644
--- a/coverage/codeunit.py
+++ b/coverage/codeunit.py
@@ -2,7 +2,7 @@
import os, re
-from coverage.backward import open_python_source, string_class, StringIO
+from coverage.backward import open_python_source, string_class
from coverage.misc import CoverageException, NoSource
from coverage.parser import CodeParser, PythonParser
from coverage.phystokens import source_token_lines, source_encoding
@@ -110,16 +110,16 @@ class CodeUnit(object):
root = os.path.splitdrive(self.name)[1]
return root.replace('\\', '_').replace('/', '_').replace('.', '_')
- def source_file(self):
- """Return an open file for reading the source of the code unit."""
+ def source(self):
+ """Return the source code, as a string."""
if os.path.exists(self.filename):
# A regular text file: open it.
- return open_python_source(self.filename)
+ return open_python_source(self.filename).read()
# Maybe it's in a zip file?
source = self.file_locator.get_zip_data(self.filename)
if source is not None:
- return StringIO(source)
+ return source
# Couldn't find source.
raise CoverageException(
@@ -139,8 +139,6 @@ class CodeUnit(object):
class PythonCodeUnit(CodeUnit):
"""Represents a Python file."""
- parser_class = PythonParser
-
def _adjust_filename(self, fname):
# .pyc files should always refer to a .py instead.
if fname.endswith(('.pyc', '.pyo')):
@@ -149,7 +147,13 @@ class PythonCodeUnit(CodeUnit):
fname = fname[:-9] + ".py"
return fname
- def find_source(self, filename):
+ def get_parser(self, exclude=None):
+ actual_filename, source = self._find_source(self.filename)
+ return PythonParser(
+ text=source, filename=actual_filename, exclude=exclude,
+ )
+
+ def _find_source(self, filename):
"""Find the source for `filename`.
Returns two values: the actual filename, and the source.
@@ -227,10 +231,8 @@ def mako_template_name(py_filename):
class MakoParser(CodeParser):
- def __init__(self, cu, text, filename, exclude):
+ def __init__(self, cu, exclude):
self.cu = cu
- self.text = text
- self.filename = filename
self.exclude = exclude
def parse_source(self):
@@ -261,14 +263,15 @@ class MakoParser(CodeParser):
class MakoCodeUnit(CodeUnit):
- parser_class = MakoParser
-
def __init__(self, *args, **kwargs):
super(MakoCodeUnit, self).__init__(*args, **kwargs)
self.mako_filename = mako_template_name(self.filename)
- def source_file(self):
- return open(self.mako_filename)
+ def source(self):
+ return open(self.mako_filename).read()
+
+ def get_parser(self, exclude=None):
+ return MakoParser(self, exclude)
def find_source(self, filename):
"""Find the source for `filename`.
diff --git a/coverage/html.py b/coverage/html.py
index 15afca8e..159ae581 100644
--- a/coverage/html.py
+++ b/coverage/html.py
@@ -148,9 +148,7 @@ class HtmlReporter(Reporter):
def html_file(self, cu, analysis):
"""Generate an HTML file for one source file."""
- source_file = cu.source_file()
- with source_file:
- source = source_file.read()
+ source = cu.source()
# Find out if the file on disk is already correct.
flat_rootname = cu.flat_rootname()
diff --git a/coverage/parser.py b/coverage/parser.py
index 88aad65f..5bb15466 100644
--- a/coverage/parser.py
+++ b/coverage/parser.py
@@ -30,7 +30,7 @@ class CodeParser(object):
class PythonParser(CodeParser):
"""Parse code to find executable lines, excluded lines, etc."""
- def __init__(self, cu, text=None, filename=None, exclude=None):
+ def __init__(self, text=None, filename=None, exclude=None):
"""
Source can be provided as `text`, the text itself, or `filename`, from
which the text will be read. Excluded lines are those that match
diff --git a/coverage/results.py b/coverage/results.py
index e422730d..ce9e0fa4 100644
--- a/coverage/results.py
+++ b/coverage/results.py
@@ -14,11 +14,7 @@ class Analysis(object):
self.code_unit = code_unit
self.filename = self.code_unit.filename
- actual_filename, source = self.code_unit.find_source(self.filename)
-
- self.parser = code_unit.parser_class(
- code_unit,
- text=source, filename=actual_filename,
+ self.parser = code_unit.get_parser(
exclude=self.coverage._exclude_regex('exclude')
)
self.statements, self.excluded = self.parser.parse_source()
diff --git a/tests/test_codeunit.py b/tests/test_codeunit.py
index e4912e11..fe82ea1c 100644
--- a/tests/test_codeunit.py
+++ b/tests/test_codeunit.py
@@ -31,9 +31,9 @@ class CodeUnitTest(CoverageTest):
self.assertEqual(acu[0].flat_rootname(), "aa_afile")
self.assertEqual(bcu[0].flat_rootname(), "aa_bb_bfile")
self.assertEqual(ccu[0].flat_rootname(), "aa_bb_cc_cfile")
- self.assertEqual(acu[0].source_file().read(), "# afile.py\n")
- self.assertEqual(bcu[0].source_file().read(), "# bfile.py\n")
- self.assertEqual(ccu[0].source_file().read(), "# cfile.py\n")
+ self.assertEqual(acu[0].source(), "# afile.py\n")
+ self.assertEqual(bcu[0].source(), "# bfile.py\n")
+ self.assertEqual(ccu[0].source(), "# cfile.py\n")
def test_odd_filenames(self):
acu = code_unit_factory("aa/afile.odd.py", FileLocator())
@@ -45,9 +45,9 @@ class CodeUnitTest(CoverageTest):
self.assertEqual(acu[0].flat_rootname(), "aa_afile_odd")
self.assertEqual(bcu[0].flat_rootname(), "aa_bb_bfile_odd")
self.assertEqual(b2cu[0].flat_rootname(), "aa_bb_odd_bfile")
- self.assertEqual(acu[0].source_file().read(), "# afile.odd.py\n")
- self.assertEqual(bcu[0].source_file().read(), "# bfile.odd.py\n")
- self.assertEqual(b2cu[0].source_file().read(), "# bfile.py\n")
+ self.assertEqual(acu[0].source(), "# afile.odd.py\n")
+ self.assertEqual(bcu[0].source(), "# bfile.odd.py\n")
+ self.assertEqual(b2cu[0].source(), "# bfile.py\n")
def test_modules(self):
import aa, aa.bb, aa.bb.cc
@@ -58,9 +58,9 @@ class CodeUnitTest(CoverageTest):
self.assertEqual(cu[0].flat_rootname(), "aa")
self.assertEqual(cu[1].flat_rootname(), "aa_bb")
self.assertEqual(cu[2].flat_rootname(), "aa_bb_cc")
- self.assertEqual(cu[0].source_file().read(), "# aa\n")
- self.assertEqual(cu[1].source_file().read(), "# bb\n")
- self.assertEqual(cu[2].source_file().read(), "") # yes, empty
+ self.assertEqual(cu[0].source(), "# aa\n")
+ self.assertEqual(cu[1].source(), "# bb\n")
+ self.assertEqual(cu[2].source(), "") # yes, empty
def test_module_files(self):
import aa.afile, aa.bb.bfile, aa.bb.cc.cfile
@@ -72,9 +72,9 @@ class CodeUnitTest(CoverageTest):
self.assertEqual(cu[0].flat_rootname(), "aa_afile")
self.assertEqual(cu[1].flat_rootname(), "aa_bb_bfile")
self.assertEqual(cu[2].flat_rootname(), "aa_bb_cc_cfile")
- self.assertEqual(cu[0].source_file().read(), "# afile.py\n")
- self.assertEqual(cu[1].source_file().read(), "# bfile.py\n")
- self.assertEqual(cu[2].source_file().read(), "# cfile.py\n")
+ self.assertEqual(cu[0].source(), "# afile.py\n")
+ self.assertEqual(cu[1].source(), "# bfile.py\n")
+ self.assertEqual(cu[2].source(), "# cfile.py\n")
def test_comparison(self):
acu = code_unit_factory("aa/afile.py", FileLocator())[0]
@@ -97,7 +97,7 @@ class CodeUnitTest(CoverageTest):
self.assert_doesnt_exist(egg1.__file__)
cu = code_unit_factory([egg1, egg1.egg1], FileLocator())
- self.assertEqual(cu[0].source_file().read(), "")
- self.assertEqual(cu[1].source_file().read().split("\n")[0],
+ self.assertEqual(cu[0].source(), "")
+ self.assertEqual(cu[1].source().split("\n")[0],
"# My egg file!"
)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 5b90f342..a392ea03 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -13,7 +13,7 @@ class PythonParserTest(CoverageTest):
def parse_source(self, text):
"""Parse `text` as source, and return the `PythonParser` used."""
text = textwrap.dedent(text)
- parser = PythonParser(None, text=text, exclude="nocover")
+ parser = PythonParser(text=text, exclude="nocover")
parser.parse_source()
return parser
@@ -98,7 +98,7 @@ class ParserFileTest(CoverageTest):
def parse_file(self, filename):
"""Parse `text` as source, and return the `PythonParser` used."""
- parser = PythonParser(None, filename=filename, exclude="nocover")
+ parser = PythonParser(filename=filename, exclude="nocover")
parser.parse_source()
return parser