author    Ned Batchelder <ned@nedbatchelder.com>    2016-01-02 10:18:04 -0500
committer Ned Batchelder <ned@nedbatchelder.com>    2016-01-02 10:18:04 -0500
commit    baf18bed45cbd943f379f9ca4e7747fb607552c8 (patch)
tree      8dd7986ce95861ebde3a2a95837a7ad20c01a96e /coverage/parser.py
parent    82dae969e9318e35bccfc08c0e652cbb931403c6 (diff)
download  python-coveragepy-baf18bed45cbd943f379f9ca4e7747fb607552c8.tar.gz
Handle yield-from and await. All tests pass
Diffstat (limited to 'coverage/parser.py')
-rw-r--r--  coverage/parser.py  88
1 file changed, 60 insertions(+), 28 deletions(-)
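The commit teaches the AST arc analyzer that a statement containing a yield-from or an await can hand control back out of the enclosing function, so such statements get a return-style exit arc. A minimal sketch of the kind of source this is about (the function names are illustrative, not from the commit):

    async def fetch(reader):
        data = await reader.read()    # await buried inside an assignment
        return data

    def relay(gen):
        yield from gen                # yield-from as a bare statement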
diff --git a/coverage/parser.py b/coverage/parser.py
index 2396fb8..0462802 100644
--- a/coverage/parser.py
+++ b/coverage/parser.py
@@ -327,11 +327,17 @@ class AstArcAnalyzer(object):
     def __init__(self, text):
         self.root_node = ast.parse(text)
         if int(os.environ.get("COVERAGE_ASTDUMP", 0)):
+            # Dump the AST so that failing tests have helpful output.
             ast_dump(self.root_node)
 
         self.arcs = None
         self.block_stack = []
 
+    def collect_arcs(self):
+        self.arcs = set()
+        self.add_arcs_for_code_objects(self.root_node)
+        return self.arcs
+
     def blocks(self):
         """Yield the blocks in nearest-to-farthest order."""
         return reversed(self.block_stack)
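With collect_arcs moved up next to __init__, the analyzer's entry point reads as: build it from source text, then ask for the arc set. Roughly (a sketch, assuming the class is driven directly; the sample source is made up):

    analyzer = AstArcAnalyzer("a = 1\nif a:\n    b = 2\n")
    arcs = analyzer.collect_arcs()    # a set of (from_line, to_line) pairs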
@@ -361,16 +367,19 @@ class AstArcAnalyzer(object):
     def line_default(self, node):
         return node.lineno
 
-    def collect_arcs(self):
-        self.arcs = set()
-        self.add_arcs_for_code_objects(self.root_node)
-        return self.arcs
-
     def add_arcs(self, node):
-        """add the arcs for `node`.
+        """Add the arcs for `node`.
 
         Return a set of line numbers, exits from this node to the next.
         """
+        # Yield-froms and awaits can appear anywhere.
+        # TODO: this is probably over-doing it, and too expensive. Can we
+        # instrument the ast walking to see how many nodes we are revisiting?
+        if isinstance(node, ast.stmt):
+            for name, value in ast.iter_fields(node):
+                if isinstance(value, ast.expr) and self.contains_return_expression(value):
+                    self.process_return_exits([self.line_for_node(node)])
+                    break
         node_name = node.__class__.__name__
         handler = getattr(self, "handle_" + node_name, self.handle_default)
         return handler(node)
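add_arcs dispatches on the node's class name with getattr, falling back to handle_default. The same pattern in a stripped-down, standalone form (the class and handler names here are illustrative, not coverage.py's):

    import ast

    class Dispatcher(object):
        def dispatch(self, node):
            # Look up a handler named after the node class, e.g. handle_If.
            handler = getattr(self, "handle_" + node.__class__.__name__, self.handle_default)
            return handler(node)

        def handle_If(self, node):
            return "If at line {0}".format(node.lineno)

        def handle_default(self, node):
            return "unhandled: {0}".format(node.__class__.__name__)

    tree = ast.parse("if x:\n    pass")
    print(Dispatcher().dispatch(tree.body[0]))    # If at line 1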
@@ -404,6 +413,7 @@ class AstArcAnalyzer(object):
     # TODO: multi-line listcomps
     # TODO: nested function definitions
     # TODO: multiple `except` clauses
+    # TODO: return->finally
 
     def process_break_exits(self, exits):
         for block in self.blocks():
@@ -443,6 +453,7 @@ class AstArcAnalyzer(object):
 
     def process_return_exits(self, exits):
         for block in self.blocks():
+            # TODO: need a check here for TryBlock
             if isinstance(block, FunctionBlock):
                 # TODO: what if there is no enclosing function?
                 for exit in exits:
@@ -587,6 +598,7 @@ class AstArcAnalyzer(object):
     def handle_default(self, node):
         node_name = node.__class__.__name__
         if node_name not in ["Assign", "Assert", "AugAssign", "Expr", "Import", "Pass", "Print"]:
+            # TODO: put 1/0 here to find unhandled nodes.
             print("*** Unhandled: {0}".format(node))
         return set([self.line_for_node(node)])
@@ -628,6 +640,14 @@ class AstArcAnalyzer(object):
         self.arcs.add((start, -start))
         # TODO: test multi-line lambdas
 
+    def contains_return_expression(self, node):
+        """Is there a yield-from or await in `node` someplace?"""
+        for child in ast.walk(node):
+            if child.__class__.__name__ in ["YieldFrom", "Await"]:
+                return True
+
+        return False
+
 
 ## Opcodes that guide the ByteParser.
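contains_return_expression simply walks the expression with ast.walk and matches YieldFrom or Await by class name, presumably so the check does not require ast.Await to exist on older Pythons. A standalone check in the same spirit (the helper name is made up for illustration):

    import ast

    def has_yield_from_or_await(tree):
        # True if any node anywhere in `tree` is a YieldFrom or Await.
        return any(
            child.__class__.__name__ in ("YieldFrom", "Await")
            for child in ast.walk(tree)
        )

    print(has_yield_from_or_await(ast.parse("def f(g):\n    yield from g")))    # True
    print(has_yield_from_or_await(ast.parse("x = 1")))                          # False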
@@ -1045,7 +1065,13 @@ class Chunk(object):
         )
 
 
-SKIP_FIELDS = ["ctx"]
+SKIP_DUMP_FIELDS = ["ctx"]
+
+def is_simple_value(value):
+    return (
+        value in [None, [], (), {}, set()] or
+        isinstance(value, (string_class, int, float))
+    )
 
 def ast_dump(node, depth=0):
     indent = " " * depth
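is_simple_value decides which field values ast_dump can print inline: None, empty containers, strings, and numbers count as simple, everything else is recursed into. A quick illustration, with the predicate redefined so the snippet stands alone (using str where the patch uses coverage's string_class shim):

    def is_simple_value(value):
        return (
            value in [None, [], (), {}, set()] or
            isinstance(value, (str, int, float))
        )

    print(is_simple_value(3), is_simple_value("lineno"), is_simple_value([]))    # True True True
    print(is_simple_value([1, 2]))    # False: non-empty lists are dumped element by element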
@@ -1055,30 +1081,36 @@ def ast_dump(node, depth=0):
     lineno = getattr(node, "lineno", None)
     if lineno is not None:
-        linemark = " @ {0}".format(lineno)
+        linemark = " @ {0}".format(node.lineno)
     else:
         linemark = ""
-    print("{0}<{1}{2}".format(indent, node.__class__.__name__, linemark))
-
-    indent += " "
-    for field_name, value in ast.iter_fields(node):
-        if field_name in SKIP_FIELDS:
-            continue
-        prefix = "{0}{1}:".format(indent, field_name)
-        if value is None:
-            print("{0} None".format(prefix))
-        elif isinstance(value, (string_class, int, float)):
-            print("{0} {1!r}".format(prefix, value))
-        elif isinstance(value, list):
-            if value == []:
-                print("{0} []".format(prefix))
-            else:
+    head = "{0}<{1}{2}".format(indent, node.__class__.__name__, linemark)
+
+    named_fields = [
+        (name, value)
+        for name, value in ast.iter_fields(node)
+        if name not in SKIP_DUMP_FIELDS
+    ]
+    if not named_fields:
+        print("{0}>".format(head))
+    elif len(named_fields) == 1 and is_simple_value(named_fields[0][1]):
+        field_name, value = named_fields[0]
+        print("{0} {1}: {2!r}>".format(head, field_name, value))
+    else:
+        print(head)
+        print("{0}# mro: {1}".format(indent, ", ".join(c.__name__ for c in node.__class__.__mro__[1:])))
+        next_indent = indent + " "
+        for field_name, value in named_fields:
+            prefix = "{0}{1}:".format(next_indent, field_name)
+            if is_simple_value(value):
+                print("{0} {1!r}".format(prefix, value))
+            elif isinstance(value, list):
                 print("{0} [".format(prefix))
                 for n in value:
                     ast_dump(n, depth + 8)
-                print("{0}]".format(indent))
-        else:
-            print(prefix)
-            ast_dump(value, depth + 8)
+                print("{0}]".format(next_indent))
+            else:
+                print(prefix)
+                ast_dump(value, depth + 8)
-    print("{0}>".format(" " * depth))
+        print("{0}>".format(indent))
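The rewritten dumper is normally reached through the COVERAGE_ASTDUMP environment variable checked in AstArcAnalyzer.__init__, but it can also be driven by hand; a minimal sketch, assuming ast_dump stays importable from coverage.parser as in this patch:

    import ast
    from coverage.parser import ast_dump

    # Dump the tree for a one-line assignment.
    ast_dump(ast.parse("x = y + 1"))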