path: root/Cython/Compiler/ParseTreeTransforms.py
Diffstat (limited to 'Cython/Compiler/ParseTreeTransforms.py')
-rw-r--r--  Cython/Compiler/ParseTreeTransforms.py  995
1 file changed, 801 insertions(+), 194 deletions(-)
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
index 0e86d5b0e..8008cba9a 100644
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -1,3 +1,5 @@
+# cython: language_level=3str
+
from __future__ import absolute_import
import cython
@@ -54,6 +56,12 @@ class SkipDeclarations(object):
def visit_CStructOrUnionDefNode(self, node):
return node
+ def visit_CppClassNode(self, node):
+ if node.visibility != "extern":
+ # Need to traverse methods.
+ self.visitchildren(node)
+ return node
+
class NormalizeTree(CythonTransform):
"""
@@ -81,6 +89,13 @@ class NormalizeTree(CythonTransform):
self.is_in_statlist = False
self.is_in_expr = False
+ def visit_ModuleNode(self, node):
+ self.visitchildren(node)
+ if not isinstance(node.body, Nodes.StatListNode):
+ # This can happen when the body only consists of a single (unused) declaration and no statements.
+ node.body = Nodes.StatListNode(pos=node.pos, stats=[node.body])
+ return node
+
def visit_ExprNode(self, node):
stacktmp = self.is_in_expr
self.is_in_expr = True
@@ -170,8 +185,9 @@ class PostParse(ScopeTrackingTransform):
Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted.
- """
+ - Some invalid uses of := assignment expressions are detected
+ """
def __init__(self, context):
super(PostParse, self).__init__(context)
self.specialattribute_handlers = {
@@ -203,7 +219,9 @@ class PostParse(ScopeTrackingTransform):
node.def_node = Nodes.DefNode(
node.pos, name=node.name, doc=None,
args=[], star_arg=None, starstar_arg=None,
- body=node.loop, is_async_def=collector.has_await)
+ body=node.loop, is_async_def=collector.has_await,
+ is_generator_expression=True)
+ _AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node)
return node
@@ -214,6 +232,7 @@ class PostParse(ScopeTrackingTransform):
collector.visitchildren(node.loop)
if collector.has_await:
node.has_local_scope = True
+ _AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node)
return node
@@ -246,7 +265,7 @@ class PostParse(ScopeTrackingTransform):
if decl is not declbase:
raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
handler(decl)
- continue # Remove declaration
+ continue # Remove declaration
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
first_assignment = self.scope_type != 'module'
stats.append(Nodes.SingleAssignmentNode(node.pos,
@@ -349,6 +368,141 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node)
return node
+ def visit_AssertStatNode(self, node):
+ """Extract the exception raising into a RaiseStatNode to simplify GIL handling.
+ """
+ if node.exception is None:
+ node.exception = Nodes.RaiseStatNode(
+ node.pos,
+ exc_type=ExprNodes.NameNode(node.pos, name=EncodedString("AssertionError")),
+ exc_value=node.value,
+ exc_tb=None,
+ cause=None,
+ builtin_exc_name="AssertionError",
+ wrap_tuple_value=True,
+ )
+ node.value = None
+ self.visitchildren(node)
+ return node
+
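Illustrative sketch (not part of the diff): the rewrite above turns a plain assert into an explicit RaiseStatNode, so it behaves roughly like the hand-written equivalent below, which is what the later GIL handling relies on.

    x = 0
    assert x == 0, "x must be zero"
    # ...is handled much like:
    if not (x == 0):
        raise AssertionError("x must be zero")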
+class _AssignmentExpressionTargetNameFinder(TreeVisitor):
+ def __init__(self):
+ super(_AssignmentExpressionTargetNameFinder, self).__init__()
+ self.target_names = {}
+
+ def find_target_names(self, target):
+ if target.is_name:
+ return [target.name]
+ elif target.is_sequence_constructor:
+ names = []
+ for arg in target.args:
+ names.extend(self.find_target_names(arg))
+ return names
+ # other targets are possible, but it isn't necessary to investigate them here
+ return []
+
+ def visit_ForInStatNode(self, node):
+ self.target_names[node] = tuple(self.find_target_names(node.target))
+ self.visitchildren(node)
+
+ def visit_ComprehensionNode(self, node):
+ pass # don't recurse into nested comprehensions
+
+ def visit_LambdaNode(self, node):
+ pass # don't recurse into nested lambdas/generator expressions
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+
+
+class _AssignmentExpressionChecker(TreeVisitor):
+ """
+ Enforces rules on AssignmentExpressions within generator expressions and comprehensions
+ """
+ def __init__(self, loop_node, scope_is_class):
+ super(_AssignmentExpressionChecker, self).__init__()
+
+ target_name_finder = _AssignmentExpressionTargetNameFinder()
+ target_name_finder.visit(loop_node)
+ self.target_names_dict = target_name_finder.target_names
+ self.in_iterator = False
+ self.in_nested_generator = False
+ self.scope_is_class = scope_is_class
+ self.current_target_names = ()
+ self.all_target_names = set()
+ for names in self.target_names_dict.values():
+ self.all_target_names.update(names)
+
+ def _reset_state(self):
+ old_state = (self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names)
+ # note: not resetting self.in_iterator here, see visit_LambdaNode() below
+ self.in_nested_generator = False
+ self.scope_is_class = False
+ self.current_target_names = ()
+ self.all_target_names = set()
+ return old_state
+
+ def _set_state(self, old_state):
+ self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names = old_state
+
+ @classmethod
+ def do_checks(cls, loop_node, scope_is_class):
+ checker = cls(loop_node, scope_is_class)
+ checker.visit(loop_node)
+
+ def visit_ForInStatNode(self, node):
+ if self.in_nested_generator:
+ self.visitchildren(node) # once nested, don't do anything special
+ return
+
+ current_target_names = self.current_target_names
+ target_name = self.target_names_dict.get(node, None)
+ if target_name:
+ self.current_target_names += target_name
+
+ self.in_iterator = True
+ self.visit(node.iterator)
+ self.in_iterator = False
+ self.visitchildren(node, exclude=("iterator",))
+
+ self.current_target_names = current_target_names
+
+ def visit_AssignmentExpressionNode(self, node):
+ if self.in_iterator:
+ error(node.pos, "assignment expression cannot be used in a comprehension iterable expression")
+ if self.scope_is_class:
+ error(node.pos, "assignment expression within a comprehension cannot be used in a class body")
+ if node.target_name in self.current_target_names:
+ error(node.pos, "assignment expression cannot rebind comprehension iteration variable '%s'" %
+ node.target_name)
+ elif node.target_name in self.all_target_names:
+ error(node.pos, "comprehension inner loop cannot rebind assignment expression target '%s'" %
+ node.target_name)
+
+ def visit_LambdaNode(self, node):
+ # Don't reset "in_iterator" - an assignment expression in a lambda in an
+ # iterator is explicitly tested by the Python testcases and banned.
+ old_state = self._reset_state()
+ # the lambda node's "def_node" is not set up at this point, so we need to recurse into it explicitly.
+ self.visit(node.result_expr)
+ self._set_state(old_state)
+
+ def visit_ComprehensionNode(self, node):
+ in_nested_generator = self.in_nested_generator
+ self.in_nested_generator = True
+ self.visitchildren(node)
+ self.in_nested_generator = in_nested_generator
+
+ def visit_GeneratorExpressionNode(self, node):
+ in_nested_generator = self.in_nested_generator
+ self.in_nested_generator = True
+ # def_node isn't set up yet, so we need to visit the loop directly.
+ self.visit(node.loop)
+ self.in_nested_generator = in_nested_generator
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+
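Illustrative sketch (not part of the diff): plain-Python examples of the PEP 572 rules that the checker above enforces; the rejected forms are shown commented out.

    data = [1, 2, 3]
    ok = [y := x for x in data]          # allowed: ":=" in the comprehension body
    # bad = [x for x in (y := data)]     # rejected: ":=" inside the iterable expression
    # bad = [x := 1 for x in data]       # rejected: rebinds the iteration variable 'x'
    print(ok, y)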
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once.
@@ -419,7 +573,7 @@ def sort_common_subsequences(items):
return b.is_sequence_constructor and contains(b.args, a)
for pos, item in enumerate(items):
- key = item[1] # the ResultRefNode which has already been injected into the sequences
+ key = item[1] # the ResultRefNode which has already been injected into the sequences
new_pos = pos
for i in range(pos-1, -1, -1):
if lower_than(key, items[i][0]):
@@ -449,7 +603,7 @@ def flatten_parallel_assignments(input, output):
# recursively, so that nested structures get matched as well.
rhs = input[-1]
if (not (rhs.is_sequence_constructor or isinstance(rhs, ExprNodes.UnicodeNode))
- or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])):
+ or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])):
output.append(input)
return
@@ -533,7 +687,7 @@ def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args)
targets.append(expr)
# the starred target itself, must be assigned a (potentially empty) list
- target = lhs_args[starred].target # unpack starred node
+ target = lhs_args[starred].target # unpack starred node
starred_rhs = rhs_args[starred:]
if lhs_remaining:
starred_rhs = starred_rhs[:-lhs_remaining]
@@ -579,19 +733,19 @@ class PxdPostParse(CythonTransform, SkipDeclarations):
err = self.ERR_INLINE_ONLY
if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
- and node.name in ('__getbuffer__', '__releasebuffer__')):
- err = None # allow these slots
+ and node.name in ('__getbuffer__', '__releasebuffer__')):
+ err = None # allow these slots
if isinstance(node, Nodes.CFuncDefNode):
if (u'inline' in node.modifiers and
- self.scope_type in ('pxd', 'cclass')):
+ self.scope_type in ('pxd', 'cclass')):
node.inline_in_pxd = True
if node.visibility != 'private':
err = self.ERR_NOGO_WITH_INLINE % node.visibility
elif node.api:
err = self.ERR_NOGO_WITH_INLINE % 'api'
else:
- err = None # allow inline function
+ err = None # allow inline function
else:
err = self.ERR_INLINE_ONLY
@@ -630,6 +784,9 @@ class InterpretCompilerDirectives(CythonTransform):
- Command-line arguments overriding these
- @cython.directivename decorators
- with cython.directivename: statements
+ - replaces "cython.compiled" with BoolNode(value=True)
+ allowing unreachable blocks to be removed at a fairly early stage
+ before cython typing rules are applied
This transform is responsible for interpreting these various sources
and storing the directives in two ways:
@@ -668,17 +825,27 @@ class InterpretCompilerDirectives(CythonTransform):
'operator.comma' : ExprNodes.c_binop_constructor(','),
}
- special_methods = set(['declare', 'union', 'struct', 'typedef',
- 'sizeof', 'cast', 'pointer', 'compiled',
- 'NULL', 'fused_type', 'parallel'])
+ special_methods = {
+ 'declare', 'union', 'struct', 'typedef',
+ 'sizeof', 'cast', 'pointer', 'compiled',
+ 'NULL', 'fused_type', 'parallel',
+ }
special_methods.update(unop_method_nodes)
- valid_parallel_directives = set([
+ valid_cython_submodules = {
+ 'cimports',
+ 'dataclasses',
+ 'operator',
+ 'parallel',
+ 'view',
+ }
+
+ valid_parallel_directives = {
"parallel",
"prange",
"threadid",
#"threadsavailable",
- ])
+ }
def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context)
@@ -701,6 +868,44 @@ class InterpretCompilerDirectives(CythonTransform):
error(pos, "Invalid directive: '%s'." % (directive,))
return True
+ def _check_valid_cython_module(self, pos, module_name):
+ if not module_name.startswith("cython."):
+ return
+ submodule = module_name.split('.', 2)[1]
+ if submodule in self.valid_cython_submodules:
+ return
+
+ extra = ""
+ # This is very rarely used, so don't waste space on static tuples.
+ hints = [
+ line.split() for line in """\
+ imp cimports
+ cimp cimports
+ para parallel
+ parra parallel
+ dataclass dataclasses
+ """.splitlines()[:-1]
+ ]
+ for wrong, correct in hints:
+ if module_name.startswith("cython." + wrong):
+ extra = "Did you mean 'cython.%s' ?" % correct
+ break
+ if not extra:
+ is_simple_cython_name = submodule in Options.directive_types
+ if not is_simple_cython_name and not submodule.startswith("_"):
+ # Try to find it in the Shadow module (i.e. the pure Python namespace of cython.*).
+ # FIXME: use an internal reference of "cython.*" names instead of Shadow.py
+ from .. import Shadow
+ is_simple_cython_name = hasattr(Shadow, submodule)
+ if is_simple_cython_name:
+ extra = "Instead, use 'import cython' and then 'cython.%s'." % submodule
+
+ error(pos, "'%s' is not a valid cython.* module%s%s" % (
+ module_name,
+ ". " if extra else "",
+ extra,
+ ))
+
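A minimal standalone sketch (not part of the diff) of the prefix-hint lookup used above; the helper name and data layout here are simplifications, not the transform's actual API.

    _HINTS = [("imp", "cimports"), ("cimp", "cimports"),
              ("para", "parallel"), ("parra", "parallel"),
              ("dataclass", "dataclasses")]

    def _suggest(module_name):
        # map common misspellings of cython.* submodules to a "Did you mean" hint
        for wrong, correct in _HINTS:
            if module_name.startswith("cython." + wrong):
                return "Did you mean 'cython.%s' ?" % correct
        return ""

    assert _suggest("cython.paralel") == "Did you mean 'cython.parallel' ?"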
# Set up processing and handle the cython: comments.
def visit_ModuleNode(self, node):
for key in sorted(node.directive_comments):
@@ -717,6 +922,12 @@ class InterpretCompilerDirectives(CythonTransform):
node.cython_module_names = self.cython_module_names
return node
+ def visit_CompilerDirectivesNode(self, node):
+ old_directives, self.directives = self.directives, node.directives
+ self.visitchildren(node)
+ self.directives = old_directives
+ return node
+
# The following four functions track imports and cimports that
# begin with "cython"
def is_cython_directive(self, name):
@@ -749,22 +960,36 @@ class InterpretCompilerDirectives(CythonTransform):
return result
def visit_CImportStatNode(self, node):
- if node.module_name == u"cython":
+ module_name = node.module_name
+ if module_name == u"cython.cimports":
+ error(node.pos, "Cannot cimport the 'cython.cimports' package directly, only submodules.")
+ if module_name.startswith(u"cython.cimports."):
+ if node.as_name and node.as_name != u'cython':
+ node.module_name = module_name[len(u"cython.cimports."):]
+ return node
+ error(node.pos,
+ "Python cimports must use 'from cython.cimports... import ...'"
+ " or 'import ... as ...', not just 'import ...'")
+
+ if module_name == u"cython":
self.cython_module_names.add(node.as_name or u"cython")
- elif node.module_name.startswith(u"cython."):
- if node.module_name.startswith(u"cython.parallel."):
+ elif module_name.startswith(u"cython."):
+ if module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
- if node.module_name == u"cython.parallel":
+ else:
+ self._check_valid_cython_module(node.pos, module_name)
+
+ if module_name == u"cython.parallel":
if node.as_name and node.as_name != u"cython":
- self.parallel_directives[node.as_name] = node.module_name
+ self.parallel_directives[node.as_name] = module_name
else:
self.cython_module_names.add(u"cython")
self.parallel_directives[
- u"cython.parallel"] = node.module_name
+ u"cython.parallel"] = module_name
self.module_scope.use_utility_code(
UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))
elif node.as_name:
- self.directive_names[node.as_name] = node.module_name[7:]
+ self.directive_names[node.as_name] = module_name[7:]
else:
self.cython_module_names.add(u"cython")
# if this cimport was a compiler directive, we don't
@@ -773,26 +998,31 @@ class InterpretCompilerDirectives(CythonTransform):
return node
def visit_FromCImportStatNode(self, node):
- if not node.relative_level and (
- node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
- submodule = (node.module_name + u".")[7:]
+ module_name = node.module_name
+ if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."):
+ # only supported for convenience
+ return self._create_cimport_from_import(
+ node.pos, module_name, node.relative_level, node.imported_names)
+ elif not node.relative_level and (
+ module_name == u"cython" or module_name.startswith(u"cython.")):
+ self._check_valid_cython_module(node.pos, module_name)
+ submodule = (module_name + u".")[7:]
newimp = []
-
- for pos, name, as_name, kind in node.imported_names:
+ for pos, name, as_name in node.imported_names:
full_name = submodule + name
qualified_name = u"cython." + full_name
-
if self.is_parallel_directive(qualified_name, node.pos):
# from cython cimport parallel, or
# from cython.parallel cimport parallel, prange, ...
self.parallel_directives[as_name or name] = qualified_name
elif self.is_cython_directive(full_name):
self.directive_names[as_name or name] = full_name
- if kind is not None:
- self.context.nonfatal_error(PostParseError(pos,
- "Compiler directive imports must be plain imports"))
+ elif full_name in ['dataclasses', 'typing']:
+ self.directive_names[as_name or name] = full_name
+ # unlike many directives, still treat it as a regular module
+ newimp.append((pos, name, as_name))
else:
- newimp.append((pos, name, as_name, kind))
+ newimp.append((pos, name, as_name))
if not newimp:
return None
@@ -801,9 +1031,18 @@ class InterpretCompilerDirectives(CythonTransform):
return node
def visit_FromImportStatNode(self, node):
- if (node.module.module_name.value == u"cython") or \
- node.module.module_name.value.startswith(u"cython."):
- submodule = (node.module.module_name.value + u".")[7:]
+ import_node = node.module
+ module_name = import_node.module_name.value
+ if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."):
+ imported_names = []
+ for name, name_node in node.items:
+ imported_names.append(
+ (name_node.pos, name, None if name == name_node.name else name_node.name))
+ return self._create_cimport_from_import(
+ node.pos, module_name, import_node.level, imported_names)
+ elif module_name == u"cython" or module_name.startswith(u"cython."):
+ self._check_valid_cython_module(import_node.module_name.pos, module_name)
+ submodule = (module_name + u".")[7:]
newimp = []
for name, name_node in node.items:
full_name = submodule + name
@@ -819,20 +1058,34 @@ class InterpretCompilerDirectives(CythonTransform):
node.items = newimp
return node
+ def _create_cimport_from_import(self, node_pos, module_name, level, imported_names):
+ if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."):
+ module_name = EncodedString(module_name[len(u"cython.cimports."):]) # may be empty
+
+ if module_name:
+ # from cython.cimports.a.b import x, y, z => from a.b cimport x, y, z
+ return Nodes.FromCImportStatNode(
+ node_pos, module_name=module_name,
+ relative_level=level,
+ imported_names=imported_names)
+ else:
+ # from cython.cimports import x, y, z => cimport x; cimport y; cimport z
+ return [
+ Nodes.CImportStatNode(
+ pos,
+ module_name=dotted_name,
+ as_name=as_name,
+ is_absolute=level == 0)
+ for pos, dotted_name, as_name in imported_names
+ ]
+
def visit_SingleAssignmentNode(self, node):
if isinstance(node.rhs, ExprNodes.ImportNode):
module_name = node.rhs.module_name.value
- is_parallel = (module_name + u".").startswith(u"cython.parallel.")
-
- if module_name != u"cython" and not is_parallel:
+ if module_name != u"cython" and not module_name.startswith("cython."):
return node
- module_name = node.rhs.module_name.value
- as_name = node.lhs.name
-
- node = Nodes.CImportStatNode(node.pos,
- module_name = module_name,
- as_name = as_name)
+ node = Nodes.CImportStatNode(node.pos, module_name=module_name, as_name=node.lhs.name)
node = self.visit_CImportStatNode(node)
else:
self.visitchildren(node)
@@ -840,16 +1093,35 @@ class InterpretCompilerDirectives(CythonTransform):
return node
def visit_NameNode(self, node):
+ if node.annotation:
+ self.visitchild(node, 'annotation')
if node.name in self.cython_module_names:
node.is_cython_module = True
else:
directive = self.directive_names.get(node.name)
if directive is not None:
node.cython_attribute = directive
+ if node.as_cython_attribute() == "compiled":
+ return ExprNodes.BoolNode(node.pos, value=True) # replace early so unused branches can be dropped
+ # before they have a chance to cause compile-errors
+ return node
+
+ def visit_AttributeNode(self, node):
+ self.visitchildren(node)
+ if node.as_cython_attribute() == "compiled":
+ return ExprNodes.BoolNode(node.pos, value=True) # replace early so unused branches can be dropped
+ # before they have a chance to cause compile-errors
+ return node
+
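Illustrative sketch (not part of the diff): a typical "cython.compiled" guard that benefits from the early BoolNode replacement above; when compiling, the dead branch can be dropped before its contents are type-checked, while plain Python simply sees False.

    import cython

    if cython.compiled:
        MODE = "compiled"      # kept when compiling; the other branch is removed early
    else:
        MODE = "interpreted"   # taken when running as plain Python
    print(MODE)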
+ def visit_AnnotationNode(self, node):
+ # for most transforms annotations are left unvisited (because they're unevaluated)
+ # however, it is important to pick up compiler directives from them
+ if node.expr:
+ self.visit(node.expr)
return node
def visit_NewExprNode(self, node):
- self.visit(node.cppclass)
+ self.visitchild(node, 'cppclass')
self.visitchildren(node)
return node
@@ -858,7 +1130,7 @@ class InterpretCompilerDirectives(CythonTransform):
# decorator), returns a list of (directivename, value) pairs.
# Otherwise, returns None
if isinstance(node, ExprNodes.CallNode):
- self.visit(node.function)
+ self.visitchild(node, 'function')
optname = node.function.as_cython_attribute()
if optname:
directivetype = Options.directive_types.get(optname)
@@ -890,7 +1162,7 @@ class InterpretCompilerDirectives(CythonTransform):
if directivetype is bool:
arg = ExprNodes.BoolNode(node.pos, value=True)
return [self.try_to_parse_directive(optname, [arg], None, node.pos)]
- elif directivetype is None:
+ elif directivetype is None or directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS:
return [(optname, None)]
else:
raise PostParseError(
@@ -945,7 +1217,7 @@ class InterpretCompilerDirectives(CythonTransform):
if len(args) != 0:
raise PostParseError(pos,
'The %s directive takes no prepositional arguments' % optname)
- return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
+ return optname, kwds.as_python_dict()
elif directivetype is list:
if kwds and len(kwds.key_value_pairs) != 0:
raise PostParseError(pos,
@@ -957,21 +1229,42 @@ class InterpretCompilerDirectives(CythonTransform):
raise PostParseError(pos,
'The %s directive takes one compile-time string argument' % optname)
return (optname, directivetype(optname, str(args[0].value)))
+ elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS:
+ # signal to pass things on without processing
+ return (optname, (args, kwds.as_python_dict() if kwds else {}))
else:
assert False
- def visit_with_directives(self, node, directives):
+ def visit_with_directives(self, node, directives, contents_directives):
+ # contents_directives may be None
if not directives:
+ assert not contents_directives
return self.visit_Node(node)
old_directives = self.directives
- new_directives = dict(old_directives)
- new_directives.update(directives)
+ new_directives = Options.copy_inherited_directives(old_directives, **directives)
+ if contents_directives is not None:
+ new_contents_directives = Options.copy_inherited_directives(
+ old_directives, **contents_directives)
+ else:
+ new_contents_directives = new_directives
if new_directives == old_directives:
return self.visit_Node(node)
self.directives = new_directives
+ if (contents_directives is not None and
+ new_contents_directives != new_directives):
+ # we need to wrap the node body in a compiler directives node
+ node.body = Nodes.StatListNode(
+ node.body.pos,
+ stats=[
+ Nodes.CompilerDirectivesNode(
+ node.body.pos,
+ directives=new_contents_directives,
+ body=node.body)
+ ]
+ )
retbody = self.visit_Node(node)
self.directives = old_directives
@@ -980,13 +1273,14 @@ class InterpretCompilerDirectives(CythonTransform):
return Nodes.CompilerDirectivesNode(
pos=retbody.pos, body=retbody, directives=new_directives)
+
# Handle decorators
def visit_FuncDefNode(self, node):
- directives = self._extract_directives(node, 'function')
- return self.visit_with_directives(node, directives)
+ directives, contents_directives = self._extract_directives(node, 'function')
+ return self.visit_with_directives(node, directives, contents_directives)
def visit_CVarDefNode(self, node):
- directives = self._extract_directives(node, 'function')
+ directives, _ = self._extract_directives(node, 'function')
for name, value in directives.items():
if name == 'locals':
node.directive_locals = value
@@ -995,27 +1289,34 @@ class InterpretCompilerDirectives(CythonTransform):
node.pos,
"Cdef functions can only take cython.locals(), "
"staticmethod, or final decorators, got %s." % name))
- return self.visit_with_directives(node, directives)
+ return self.visit_with_directives(node, directives, contents_directives=None)
def visit_CClassDefNode(self, node):
- directives = self._extract_directives(node, 'cclass')
- return self.visit_with_directives(node, directives)
+ directives, contents_directives = self._extract_directives(node, 'cclass')
+ return self.visit_with_directives(node, directives, contents_directives)
def visit_CppClassNode(self, node):
- directives = self._extract_directives(node, 'cppclass')
- return self.visit_with_directives(node, directives)
+ directives, contents_directives = self._extract_directives(node, 'cppclass')
+ return self.visit_with_directives(node, directives, contents_directives)
def visit_PyClassDefNode(self, node):
- directives = self._extract_directives(node, 'class')
- return self.visit_with_directives(node, directives)
+ directives, contents_directives = self._extract_directives(node, 'class')
+ return self.visit_with_directives(node, directives, contents_directives)
def _extract_directives(self, node, scope_name):
+ """
+ Returns two dicts - directives applied to this function/class
+ and directives applied to its contents. They aren't always the
+ same (since e.g. cfunc should not be applied to inner functions)
+ """
if not node.decorators:
- return {}
+ return {}, {}
# Split the decorators into two lists -- real decorators and directives
directives = []
realdecs = []
both = []
+ current_opt_dict = dict(self.directives)
+ missing = object()
# Decorators coming first take precedence.
for dec in node.decorators[::-1]:
new_directives = self.try_to_parse_directives(dec.decorator)
@@ -1023,8 +1324,14 @@ class InterpretCompilerDirectives(CythonTransform):
for directive in new_directives:
if self.check_directive_scope(node.pos, directive[0], scope_name):
name, value = directive
- if self.directives.get(name, object()) != value:
+ if current_opt_dict.get(name, missing) != value:
+ if name == 'cfunc' and 'ufunc' in current_opt_dict:
+ error(dec.pos, "Cannot apply @cfunc to @ufunc, please reverse the decorators.")
directives.append(directive)
+ current_opt_dict[name] = value
+ else:
+ warning(dec.pos, "Directive does not change previous value (%s%s)" % (
+ name, '=%r' % value if value is not None else ''))
if directive[0] == 'staticmethod':
both.append(dec)
# Adapt scope type based on decorators that change it.
@@ -1033,13 +1340,21 @@ class InterpretCompilerDirectives(CythonTransform):
else:
realdecs.append(dec)
if realdecs and (scope_name == 'cclass' or
- isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode, Nodes.CVarDefNode))):
+ isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))):
+ for realdec in realdecs:
+ dec_pos = realdec.pos
+ realdec = realdec.decorator
+ if ((realdec.is_name and realdec.name == "dataclass") or
+ (realdec.is_attribute and realdec.attribute == "dataclass")):
+ error(dec_pos,
+ "Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass")
+ # Note - arbitrary C function decorators are caught later in DecoratorTransform
raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
node.decorators = realdecs[::-1] + both[::-1]
# merge or override repeated directives
optdict = {}
- for directive in directives:
- name, value = directive
+ contents_optdict = {}
+ for name, value in directives:
if name in optdict:
old_value = optdict[name]
# keywords and arg lists can be merged, everything
@@ -1052,7 +1367,9 @@ class InterpretCompilerDirectives(CythonTransform):
optdict[name] = value
else:
optdict[name] = value
- return optdict
+ if name not in Options.immediate_decorator_directives:
+ contents_optdict[name] = value
+ return optdict, contents_optdict
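Illustrative pure-Python-mode sketch (not part of the diff), assuming "cfunc" is listed in Options.immediate_decorator_directives: such directives apply to the decorated function itself but are excluded from the contents dict, so they are not inherited by nested functions.

    import cython

    @cython.cfunc
    def outer() -> cython.int:
        def inner():       # stays a regular Python function; "cfunc" is not inherited
            return 1
        return inner()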
# Handle with-statements
def visit_WithStatNode(self, node):
@@ -1071,7 +1388,7 @@ class InterpretCompilerDirectives(CythonTransform):
if self.check_directive_scope(node.pos, name, 'with statement'):
directive_dict[name] = value
if directive_dict:
- return self.visit_with_directives(node.body, directive_dict)
+ return self.visit_with_directives(node.body, directive_dict, contents_directives=None)
return self.visit_Node(node)
@@ -1161,7 +1478,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
return node
def visit_CallNode(self, node):
- self.visit(node.function)
+ self.visitchild(node, 'function')
if not self.parallel_directive:
self.visitchildren(node, exclude=('function',))
return node
@@ -1194,7 +1511,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
"Nested parallel with blocks are disallowed")
self.state = 'parallel with'
- body = self.visit(node.body)
+ body = self.visitchild(node, 'body')
self.state = None
newnode.body = body
@@ -1210,13 +1527,13 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
error(node.pos, "The parallel directive must be called")
return None
- node.body = self.visit(node.body)
+ self.visitchild(node, 'body')
return node
def visit_ForInStatNode(self, node):
"Rewrite 'for i in cython.parallel.prange(...):'"
- self.visit(node.iterator)
- self.visit(node.target)
+ self.visitchild(node, 'iterator')
+ self.visitchild(node, 'target')
in_prange = isinstance(node.iterator.sequence,
Nodes.ParallelRangeNode)
@@ -1239,9 +1556,9 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
self.state = 'prange'
- self.visit(node.body)
+ self.visitchild(node, 'body')
self.state = previous_state
- self.visit(node.else_clause)
+ self.visitchild(node, 'else_clause')
return node
def visit(self, node):
@@ -1250,12 +1567,13 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
return super(ParallelRangeTransform, self).visit(node)
-class WithTransform(CythonTransform, SkipDeclarations):
+class WithTransform(VisitorTransform, SkipDeclarations):
def visit_WithStatNode(self, node):
self.visitchildren(node, 'body')
pos = node.pos
is_async = node.is_async
body, target, manager = node.body, node.target, node.manager
+ manager = node.manager = ExprNodes.ProxyNode(manager)
node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager),
@@ -1316,6 +1634,130 @@ class WithTransform(CythonTransform, SkipDeclarations):
# With statements are never inside expressions.
return node
+ visit_Node = VisitorTransform.recurse_to_children
+
+
+class _GeneratorExpressionArgumentsMarker(TreeVisitor, SkipDeclarations):
+ # called from "MarkClosureVisitor"
+ def __init__(self, gen_expr):
+ super(_GeneratorExpressionArgumentsMarker, self).__init__()
+ self.gen_expr = gen_expr
+
+ def visit_ExprNode(self, node):
+ if not node.is_literal:
+ # Don't bother tagging literal nodes
+ assert (not node.generator_arg_tag) # nobody has tagged this first
+ node.generator_arg_tag = self.gen_expr
+ self.visitchildren(node)
+
+ def visit_Node(self, node):
+ # We're only interested in the expressions that make up the iterator sequence,
+ # so don't go beyond ExprNodes (e.g. into ForFromStatNode).
+ return
+
+ def visit_GeneratorExpressionNode(self, node):
+ node.generator_arg_tag = self.gen_expr
+ # don't visit children, can't handle overlapping tags
+ # (and assume generator expressions don't end up optimized out in a way
+ # that would require overlapping tags)
+
+
+class _HandleGeneratorArguments(VisitorTransform, SkipDeclarations):
+ # used from within CreateClosureClasses
+
+ def __call__(self, node):
+ from . import Visitor
+ assert isinstance(node, ExprNodes.GeneratorExpressionNode)
+ self.gen_node = node
+
+ self.args = list(node.def_node.args)
+ self.call_parameters = list(node.call_parameters)
+ self.tag_count = 0
+ self.substitutions = {}
+
+ self.visitchildren(node)
+
+ for k, v in self.substitutions.items():
+ # doing another search for replacements here (at the end) allows us to sweep up
+ # CloneNodes too (which are often generated by the optimizer)
+ # (it could arguably be done more efficiently with a single traversal though)
+ Visitor.recursively_replace_node(node, k, v)
+
+ node.def_node.args = self.args
+ node.call_parameters = self.call_parameters
+ return node
+
+ def visit_GeneratorExpressionNode(self, node):
+ # a generator can also be substituted itself, so handle that case
+ new_node = self._handle_ExprNode(node, do_visit_children=False)
+ # However do not traverse into it. A new _HandleGeneratorArguments visitor will be used
+ # elsewhere to do that.
+ return node
+
+ def _handle_ExprNode(self, node, do_visit_children):
+ if (node.generator_arg_tag is not None and self.gen_node is not None and
+ self.gen_node == node.generator_arg_tag):
+ pos = node.pos
+ # The reason for using ".x" as the name is that this is how CPython
+ # tracks internal variables in loops (e.g.
+ # { locals() for v in range(10) }
+ # will produce "v" and ".0"). We don't replicate this behaviour completely
+ # but use it as a starting point
+ name_source = self.tag_count
+ self.tag_count += 1
+ name = EncodedString(".{0}".format(name_source))
+ def_node = self.gen_node.def_node
+ if not def_node.local_scope.lookup_here(name):
+ from . import Symtab
+ cname = EncodedString(Naming.genexpr_arg_prefix + Symtab.punycodify_name(str(name_source)))
+ name_decl = Nodes.CNameDeclaratorNode(pos=pos, name=name)
+ type = node.type
+ if type.is_reference and not type.is_fake_reference:
+ # It isn't obvious whether the right thing to do would be to capture by reference or by
+ # value (C++ itself doesn't know either for lambda functions and forces a choice).
+ # However, capture by reference involves converting to FakeReference which would require
+ # re-analysing AttributeNodes. Therefore I've picked capture-by-value out of convenience
+ # TODO - could probably be optimized by making the arg a reference but the closure not
+ # (see https://github.com/cython/cython/issues/2468)
+ type = type.ref_base_type
+
+ name_decl.type = type
+ new_arg = Nodes.CArgDeclNode(pos=pos, declarator=name_decl,
+ base_type=None, default=None, annotation=None)
+ new_arg.name = name_decl.name
+ new_arg.type = type
+
+ self.args.append(new_arg)
+ node.generator_arg_tag = None # avoid the possibility of this being caught again
+ self.call_parameters.append(node)
+ new_arg.entry = def_node.declare_argument(def_node.local_scope, new_arg)
+ new_arg.entry.cname = cname
+ new_arg.entry.in_closure = True
+
+ if do_visit_children:
+ # now visit the Node's children (but clear self.gen_node to prevent further
+ # argument substitution)
+ gen_node, self.gen_node = self.gen_node, None
+ self.visitchildren(node)
+ self.gen_node = gen_node
+
+ # replace the node inside the generator with a looked-up name
+ # (initialized_check can safely be False because the source variable will be checked
+ # before it is captured if the check is required)
+ name_node = ExprNodes.NameNode(pos, name=name, initialized_check=False)
+ name_node.entry = self.gen_node.def_node.gbody.local_scope.lookup(name_node.name)
+ name_node.type = name_node.entry.type
+ self.substitutions[node] = name_node
+ return name_node
+ if do_visit_children:
+ self.visitchildren(node)
+ return node
+
+ def visit_ExprNode(self, node):
+ return self._handle_ExprNode(node, True)
+
+ visit_Node = VisitorTransform.recurse_to_children
+
class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
"""
@@ -1328,16 +1770,16 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
_properties = None
_map_property_attribute = {
- 'getter': '__get__',
- 'setter': '__set__',
- 'deleter': '__del__',
+ 'getter': EncodedString('__get__'),
+ 'setter': EncodedString('__set__'),
+ 'deleter': EncodedString('__del__'),
}.get
def visit_CClassDefNode(self, node):
if self._properties is None:
self._properties = []
self._properties.append({})
- super(DecoratorTransform, self).visit_CClassDefNode(node)
+ node = super(DecoratorTransform, self).visit_CClassDefNode(node)
self._properties.pop()
return node
@@ -1347,6 +1789,32 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
warning(node.pos, "'property %s:' syntax is deprecated, use '@property'" % node.name, level)
return node
+ def visit_CFuncDefNode(self, node):
+ node = self.visit_FuncDefNode(node)
+ if not node.decorators:
+ return node
+ elif self.scope_type != 'cclass' or self.scope_node.visibility != "extern":
+ # at the moment cdef functions are very restricted in what decorators they can take
+ # so it's simple to test for the small number of allowed decorators....
+ if not (len(node.decorators) == 1 and node.decorators[0].decorator.is_name and
+ node.decorators[0].decorator.name == "staticmethod"):
+ error(node.decorators[0].pos, "Cdef functions cannot take arbitrary decorators.")
+ return node
+
+ ret_node = node
+ decorator_node = self._find_property_decorator(node)
+ if decorator_node:
+ if decorator_node.decorator.is_name:
+ name = node.declared_name()
+ if name:
+ ret_node = self._add_property(node, name, decorator_node)
+ else:
+ error(decorator_node.pos, "C property decorator can only be @property")
+
+ if node.decorators:
+ return self._reject_decorated_property(node, node.decorators[0])
+ return ret_node
+
def visit_DefNode(self, node):
scope_type = self.scope_type
node = self.visit_FuncDefNode(node)
@@ -1354,28 +1822,12 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
return node
# transform @property decorators
- properties = self._properties[-1]
- for decorator_node in node.decorators[::-1]:
+ decorator_node = self._find_property_decorator(node)
+ if decorator_node is not None:
decorator = decorator_node.decorator
- if decorator.is_name and decorator.name == 'property':
- if len(node.decorators) > 1:
- return self._reject_decorated_property(node, decorator_node)
- name = node.name
- node.name = EncodedString('__get__')
- node.decorators.remove(decorator_node)
- stat_list = [node]
- if name in properties:
- prop = properties[name]
- prop.pos = node.pos
- prop.doc = node.doc
- prop.body.stats = stat_list
- return []
- prop = Nodes.PropertyNode(node.pos, name=name)
- prop.doc = node.doc
- prop.body = Nodes.StatListNode(node.pos, stats=stat_list)
- properties[name] = prop
- return [prop]
- elif decorator.is_attribute and decorator.obj.name in properties:
+ if decorator.is_name:
+ return self._add_property(node, node.name, decorator_node)
+ else:
handler_name = self._map_property_attribute(decorator.attribute)
if handler_name:
if decorator.obj.name != node.name:
@@ -1386,7 +1838,7 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
elif len(node.decorators) > 1:
return self._reject_decorated_property(node, decorator_node)
else:
- return self._add_to_property(properties, node, handler_name, decorator_node)
+ return self._add_to_property(node, handler_name, decorator_node)
# we clear node.decorators, so we need to set the
# is_staticmethod/is_classmethod attributes now
@@ -1401,6 +1853,18 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
node.decorators = None
return self.chain_decorators(node, decs, node.name)
+ def _find_property_decorator(self, node):
+ properties = self._properties[-1]
+ for decorator_node in node.decorators[::-1]:
+ decorator = decorator_node.decorator
+ if decorator.is_name and decorator.name == 'property':
+ # @property
+ return decorator_node
+ elif decorator.is_attribute and decorator.obj.name in properties:
+ # @prop.setter etc.
+ return decorator_node
+ return None
+
@staticmethod
def _reject_decorated_property(node, decorator_node):
# restrict transformation to outermost decorator as wrapped properties will probably not work
@@ -1409,9 +1873,42 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
error(deco.pos, "Property methods with additional decorators are not supported")
return node
- @staticmethod
- def _add_to_property(properties, node, name, decorator):
+ def _add_property(self, node, name, decorator_node):
+ if len(node.decorators) > 1:
+ return self._reject_decorated_property(node, decorator_node)
+ node.decorators.remove(decorator_node)
+ properties = self._properties[-1]
+ is_cproperty = isinstance(node, Nodes.CFuncDefNode)
+ body = Nodes.StatListNode(node.pos, stats=[node])
+ if is_cproperty:
+ if name in properties:
+ error(node.pos, "C property redeclared")
+ if 'inline' not in node.modifiers:
+ error(node.pos, "C property method must be declared 'inline'")
+ prop = Nodes.CPropertyNode(node.pos, doc=node.doc, name=name, body=body)
+ elif name in properties:
+ prop = properties[name]
+ if prop.is_cproperty:
+ error(node.pos, "C property redeclared")
+ else:
+ node.name = EncodedString("__get__")
+ prop.pos = node.pos
+ prop.doc = node.doc
+ prop.body.stats = [node]
+ return None
+ else:
+ node.name = EncodedString("__get__")
+ prop = Nodes.PropertyNode(
+ node.pos, name=name, doc=node.doc, body=body)
+ properties[name] = prop
+ return prop
+
+ def _add_to_property(self, node, name, decorator):
+ properties = self._properties[-1]
prop = properties[node.name]
+ if prop.is_cproperty:
+ error(node.pos, "C property redeclared")
+ return None
node.name = name
node.decorators.remove(decorator)
stats = prop.body.stats
@@ -1421,7 +1918,7 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
break
else:
stats.append(node)
- return []
+ return None
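Illustrative pure-Python-mode sketch (not part of the diff): the property grouping handled above corresponds to ordinary @property/@x.setter usage on an extension type; under plain Python the cython decorators are no-ops.

    import cython

    @cython.cclass
    class Box:
        _value: cython.int

        def __init__(self, value: cython.int):
            self._value = value

        @property
        def value(self):         # becomes the property's __get__
            return self._value

        @value.setter
        def value(self, v):      # merged into the same PropertyNode as __set__
            self._value = v

    b = Box(3)
    b.value = 5
    print(b.value)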
@staticmethod
def chain_decorators(node, decorators, name):
@@ -1499,6 +1996,10 @@ class CnameDirectivesTransform(CythonTransform, SkipDeclarations):
class ForwardDeclareTypes(CythonTransform):
+ """
+ Declare all global cdef names that we allow referencing in other places,
+ before declaring everything (else) in source code order.
+ """
def visit_CompilerDirectivesNode(self, node):
env = self.module_scope
@@ -1542,6 +2043,14 @@ class ForwardDeclareTypes(CythonTransform):
entry.type.get_all_specialized_function_types()
return node
+ def visit_FuncDefNode(self, node):
+ # no traversal needed
+ return node
+
+ def visit_PyClassDefNode(self, node):
+ # no traversal needed
+ return node
+
class AnalyseDeclarationsTransform(EnvTransform):
@@ -1622,6 +2131,9 @@ if VALUE is not None:
def visit_CClassDefNode(self, node):
node = self.visit_ClassDefNode(node)
+ if node.scope and 'dataclasses.dataclass' in node.scope.directives:
+ from .Dataclass import handle_cclass_dataclass
+ handle_cclass_dataclass(node, node.scope.directives['dataclasses.dataclass'], self)
if node.scope and node.scope.implemented and node.body:
stats = []
for entry in node.scope.var_entries:
@@ -1633,8 +2145,8 @@ if VALUE is not None:
if stats:
node.body.stats += stats
if (node.visibility != 'extern'
- and not node.scope.lookup('__reduce__')
- and not node.scope.lookup('__reduce_ex__')):
+ and not node.scope.lookup('__reduce__')
+ and not node.scope.lookup('__reduce_ex__')):
self._inject_pickle_methods(node)
return node
@@ -1688,9 +2200,9 @@ if VALUE is not None:
pickle_func = TreeFragment(u"""
def __reduce_cython__(self):
- raise TypeError("%(msg)s")
+ raise TypeError, "%(msg)s"
def __setstate_cython__(self, __pyx_state):
- raise TypeError("%(msg)s")
+ raise TypeError, "%(msg)s"
""" % {'msg': msg},
level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
pickle_func.analyse_declarations(node.scope)
@@ -1702,10 +2214,11 @@ if VALUE is not None:
if not e.type.is_pyobject:
e.type.create_to_py_utility_code(env)
e.type.create_from_py_utility_code(env)
+
all_members_names = [e.name for e in all_members]
checksums = _calculate_pickle_checksums(all_members_names)
- unpickle_func_name = '__pyx_unpickle_%s' % node.class_name
+ unpickle_func_name = '__pyx_unpickle_%s' % node.punycode_class_name
# TODO(robertwb): Move the state into the third argument
# so it can be pickled *after* self is memoized.
@@ -1715,7 +2228,7 @@ if VALUE is not None:
cdef object __pyx_result
if __pyx_checksum not in %(checksums)s:
from pickle import PickleError as __pyx_PickleError
- raise __pyx_PickleError("Incompatible checksums (0x%%x vs %(checksums)s = (%(members)s))" %% __pyx_checksum)
+ raise __pyx_PickleError, "Incompatible checksums (0x%%x vs %(checksums)s = (%(members)s))" %% __pyx_checksum
__pyx_result = %(class_name)s.__new__(__pyx_type)
if __pyx_state is not None:
%(unpickle_func_name)s__set_state(<%(class_name)s> __pyx_result, __pyx_state)
@@ -1783,8 +2296,8 @@ if VALUE is not None:
for decorator in old_decorators:
func = decorator.decorator
if (not func.is_name or
- func.name not in ('staticmethod', 'classmethod') or
- env.lookup_here(func.name)):
+ func.name not in ('staticmethod', 'classmethod') or
+ env.lookup_here(func.name)):
# not a static or classmethod
decorators.append(decorator)
@@ -1802,8 +2315,10 @@ if VALUE is not None:
"Handle def or cpdef fused functions"
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
- node.py_func = self.visit(node.py_func)
+ self.visitchild(node, 'py_func')
node.update_fused_defnode_entry(env)
+ # For the moment, fused functions do not support METH_FASTCALL
+ node.py_func.entry.signature.use_fastcall = False
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, binding=True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
@@ -1846,19 +2361,6 @@ if VALUE is not None:
return node
- def _handle_nogil_cleanup(self, lenv, node):
- "Handle cleanup for 'with gil' blocks in nogil functions."
- if lenv.nogil and lenv.has_with_gil_block:
- # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
- # the entire function body in try/finally.
- # The corresponding release will be taken care of by
- # Nodes.FuncDefNode.generate_function_definitions()
- node.body = Nodes.NogilTryFinallyStatNode(
- node.body.pos,
- body=node.body,
- finally_clause=Nodes.EnsureGILNode(node.body.pos),
- finally_except_clause=Nodes.EnsureGILNode(node.body.pos))
-
def _handle_fused(self, node):
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
@@ -1890,6 +2392,8 @@ if VALUE is not None:
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
+ if type and type.is_fused and lenv.fused_to_specific:
+ type = type.specialize(lenv.fused_to_specific)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
@@ -1899,17 +2403,18 @@ if VALUE is not None:
node = self._create_fused_function(env, node)
else:
node.body.analyse_declarations(lenv)
- self._handle_nogil_cleanup(lenv, node)
self._super_visit_FuncDefNode(node)
self.seen_vars_stack.pop()
+
+ if "ufunc" in lenv.directives:
+ from . import UFuncs
+ return UFuncs.convert_to_ufunc(node)
return node
def visit_DefNode(self, node):
node = self.visit_FuncDefNode(node)
env = self.current_env()
- if isinstance(node, Nodes.DefNode) and node.is_wrapper:
- env = env.parent_scope
if (not isinstance(node, Nodes.DefNode) or
node.fused_py_func or node.is_generator_body or
not node.needs_assignment_synthesis(env)):
@@ -1959,11 +2464,17 @@ if VALUE is not None:
assmt.analyse_declarations(env)
return assmt
+ def visit_func_outer_attrs(self, node):
+ # any names in the outer attrs should not be looked up in the function "seen_vars_stack"
+ stack = self.seen_vars_stack.pop()
+ super(AnalyseDeclarationsTransform, self).visit_func_outer_attrs(node)
+ self.seen_vars_stack.append(stack)
+
def visit_ScopedExprNode(self, node):
env = self.current_env()
node.analyse_declarations(env)
# the node may or may not have a local scope
- if node.has_local_scope:
+ if node.expr_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
self.enter_scope(node, node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
@@ -1971,6 +2482,7 @@ if VALUE is not None:
self.exit_scope()
self.seen_vars_stack.pop()
else:
+
node.analyse_scoped_declarations(env)
self.visitchildren(node)
return node
@@ -1993,7 +2505,7 @@ if VALUE is not None:
# (so it can't happen later).
# Note that we don't return the original node, as it is
# never used after this phase.
- if True: # private (default)
+ if True: # private (default)
return None
self_value = ExprNodes.AttributeNode(
@@ -2085,8 +2597,8 @@ if VALUE is not None:
if node.name in self.seen_vars_stack[-1]:
entry = self.current_env().lookup(node.name)
if (entry is None or entry.visibility != 'extern'
- and not entry.scope.is_c_class_scope):
- warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
+ and not entry.scope.is_c_class_scope):
+ error(node.pos, "cdef variable '%s' declared after it is used" % node.name)
self.visitchildren(node)
return node
@@ -2096,13 +2608,12 @@ if VALUE is not None:
return None
def visit_CnameDecoratorNode(self, node):
- child_node = self.visit(node.node)
+ child_node = self.visitchild(node, 'node')
if not child_node:
return None
- if type(child_node) is list: # Assignment synthesized
- node.child_node = child_node[0]
+ if type(child_node) is list: # Assignment synthesized
+ node.node = child_node[0]
return [node] + child_node[1:]
- node.node = child_node
return node
def create_Property(self, entry):
@@ -2122,6 +2633,11 @@ if VALUE is not None:
property.doc = entry.doc
return property
+ def visit_AssignmentExpressionNode(self, node):
+ self.visitchildren(node)
+ node.analyse_declarations(self.current_env())
+ return node
+
def _calculate_pickle_checksums(member_names):
# Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons.
@@ -2130,7 +2646,7 @@ def _calculate_pickle_checksums(member_names):
member_names_string = ' '.join(member_names).encode('utf-8')
hash_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
checksums = []
- for algo_name in ['md5', 'sha256', 'sha1']:
+ for algo_name in ['sha256', 'sha1', 'md5']:
try:
mkchecksum = getattr(hashlib, algo_name)
checksum = mkchecksum(member_names_string, **hash_kwargs).hexdigest()
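Illustrative sketch (not part of the diff): the checksum input is just the space-joined member names, so reordering the preferred algorithms above changes which digest is tried first. A simplified fallback loop (using hashlib.new rather than the transform's getattr lookup):

    import hashlib

    member_names_string = ' '.join(["a", "b", "c"]).encode('utf-8')
    for algo_name in ['sha256', 'sha1', 'md5']:
        try:
            print(algo_name, hashlib.new(algo_name, member_names_string).hexdigest())
            break
        except ValueError:   # algorithm unavailable in this Python build
            continue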
@@ -2219,7 +2735,7 @@ class CalculateQualifiedNamesTransform(EnvTransform):
def visit_ClassDefNode(self, node):
orig_qualified_name = self.qualified_name[:]
entry = (getattr(node, 'entry', None) or # PyClass
- self.current_env().lookup_here(node.name)) # CClass
+ self.current_env().lookup_here(node.target.name)) # CClass
self._append_entry(entry)
self._super_visit_ClassDefNode(node)
self.qualified_name = orig_qualified_name
@@ -2328,8 +2844,8 @@ class ExpandInplaceOperators(EnvTransform):
operand2 = rhs,
inplace=True)
# Manually analyse types for new node.
- lhs.analyse_target_types(env)
- dup.analyse_types(env)
+ lhs = lhs.analyse_target_types(env)
+ dup.analyse_types(env) # FIXME: no need to reanalyse the copy, right?
binop.analyse_operation(env)
node = Nodes.SingleAssignmentNode(
node.pos,
@@ -2379,13 +2895,15 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
return_type_node = self.directives.get('returns')
if return_type_node is None and self.directives['annotation_typing']:
return_type_node = node.return_type_annotation
- # for Python anntations, prefer safe exception handling by default
+ # for Python annotations, prefer safe exception handling by default
if return_type_node is not None and except_val is None:
except_val = (None, True) # except *
elif except_val is None:
- # backward compatible default: no exception check
- except_val = (None, False)
+ # backward compatible default: no exception check, unless there's also a "@returns" declaration
+ except_val = (None, True if return_type_node else False)
if 'ccall' in self.directives:
+ if 'cfunc' in self.directives:
+ error(node.pos, "cfunc and ccall directives cannot be combined")
node = node.as_cfunction(
overridable=True, modifiers=modifiers, nogil=nogil,
returns=return_type_node, except_val=except_val)
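Illustrative pure-Python-mode sketch (not part of the diff): under the defaults above, a directive-converted function with a declared return type gets a safe exception check unless it explicitly opts out.

    import cython

    @cython.cfunc
    def scale(x: cython.double) -> cython.double:
        return 2.0 * x          # annotated return type => exception check by default

    @cython.cfunc
    @cython.exceptval(check=False)
    def scale_unchecked(x: cython.double) -> cython.double:
        return 2.0 * x          # explicitly opts out of the exception check

    print(scale(2.0), scale_unchecked(3.0))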
@@ -2437,8 +2955,6 @@ class AlignFunctionDefinitions(CythonTransform):
def visit_ModuleNode(self, node):
self.scope = node.scope
- self.directives = node.directives
- self.imported_names = set() # hack, see visit_FromImportStatNode()
self.visitchildren(node)
return node
@@ -2476,15 +2992,45 @@ class AlignFunctionDefinitions(CythonTransform):
error(pxd_def.pos, "previous declaration here")
return None
node = node.as_cfunction(pxd_def)
- elif (self.scope.is_module_scope and self.directives['auto_cpdef']
- and not node.name in self.imported_names
- and node.is_cdef_func_compatible()):
- # FIXME: cpdef-ing should be done in analyse_declarations()
- node = node.as_cfunction(scope=self.scope)
# Enable this when nested cdef functions are allowed.
# self.visitchildren(node)
return node
+ def visit_ExprNode(self, node):
+ # ignore lambdas and everything else that appears in expressions
+ return node
+
+
+class AutoCpdefFunctionDefinitions(CythonTransform):
+
+ def visit_ModuleNode(self, node):
+ self.directives = node.directives
+ self.imported_names = set() # hack, see visit_FromImportStatNode()
+ self.scope = node.scope
+ self.visitchildren(node)
+ return node
+
+ def visit_DefNode(self, node):
+ if (self.scope.is_module_scope and self.directives['auto_cpdef']
+ and node.name not in self.imported_names
+ and node.is_cdef_func_compatible()):
+ # FIXME: cpdef-ing should be done in analyse_declarations()
+ node = node.as_cfunction(scope=self.scope)
+ return node
+
+ def visit_CClassDefNode(self, node, pxd_def=None):
+ if pxd_def is None:
+ pxd_def = self.scope.lookup(node.class_name)
+ if pxd_def:
+ if not pxd_def.defined_in_pxd:
+ return node
+ outer_scope = self.scope
+ self.scope = pxd_def.type.scope
+ self.visitchildren(node)
+ if pxd_def:
+ self.scope = outer_scope
+ return node
+
def visit_FromImportStatNode(self, node):
# hack to prevent conditional import fallback functions from
# being cpdef-ed (global Python variables currently conflict
@@ -2504,8 +3050,7 @@ class RemoveUnreachableCode(CythonTransform):
if not self.current_directives['remove_unreachable']:
return node
self.visitchildren(node)
- for idx, stat in enumerate(node.stats):
- idx += 1
+ for idx, stat in enumerate(node.stats, 1):
if stat.is_terminator:
if idx < len(node.stats):
if self.current_directives['warn.unreachable']:
@@ -2604,6 +3149,8 @@ class YieldNodeCollector(TreeVisitor):
class MarkClosureVisitor(CythonTransform):
+ # In addition to marking closures, this is also responsible for finding parts of the
+ # generator iterable and marking them
def visit_ModuleNode(self, node):
self.needs_closure = False
@@ -2649,7 +3196,8 @@ class MarkClosureVisitor(CythonTransform):
star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name,
- return_type_annotation=node.return_type_annotation)
+ return_type_annotation=node.return_type_annotation,
+ is_generator_expression=node.is_generator_expression)
return coroutine
def visit_CFuncDefNode(self, node):
@@ -2673,6 +3221,19 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True
return node
+ def visit_GeneratorExpressionNode(self, node):
+ node = self.visit_LambdaNode(node)
+ if not isinstance(node.loop, Nodes._ForInStatNode):
+ # Possibly should handle ForFromStatNode
+ # but for now do nothing
+ return node
+ itseq = node.loop.iterator.sequence
+ # literals do not need replacing with an argument
+ if itseq.is_literal:
+ return node
+ _GeneratorExpressionArgumentsMarker(node).visit(itseq)
+ return node
+
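Illustrative sketch (not part of the diff): the outermost iterable of a generator expression is evaluated eagerly in the enclosing scope and passed into the generator, which is the behaviour the argument marking above prepares for.

    def values():
        print("values() evaluated")
        return [1, 2, 3]

    gen = (v * v for v in values())   # values() runs here; its result becomes the ".0" argument
    print(list(gen))                  # [1, 4, 9]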
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
@@ -2726,7 +3287,7 @@ class CreateClosureClasses(CythonTransform):
if not node.py_cfunc_node:
raise InternalError("DefNode does not have assignment node")
inner_node = node.py_cfunc_node
- inner_node.needs_self_code = False
+ inner_node.needs_closure_code = False
node.needs_outer_scope = False
if node.is_generator:
@@ -2745,6 +3306,7 @@ class CreateClosureClasses(CythonTransform):
as_name = '%s_%s' % (
target_module_scope.next_id(Naming.closure_class_prefix),
node.entry.cname.replace('.','__'))
+ as_name = EncodedString(as_name)
entry = target_module_scope.declare_c_class(
name=as_name, pos=node.pos, defining=True,
@@ -2816,6 +3378,10 @@ class CreateClosureClasses(CythonTransform):
self.visitchildren(node)
return node
+ def visit_GeneratorExpressionNode(self, node):
+ node = _HandleGeneratorArguments()(node)
+ return self.visit_LambdaNode(node)
+
class InjectGilHandling(VisitorTransform, SkipDeclarations):
"""
@@ -2824,20 +3390,20 @@ class InjectGilHandling(VisitorTransform, SkipDeclarations):
Must run before the AnalyseDeclarationsTransform to make sure the GILStatNodes get
set up, parallel sections know that the GIL is acquired inside of them, etc.
"""
- def __call__(self, root):
- self.nogil = False
- return super(InjectGilHandling, self).__call__(root)
+ nogil = False
# special node handling
- def visit_RaiseStatNode(self, node):
- """Allow raising exceptions in nogil sections by wrapping them in a 'with gil' block."""
+ def _inject_gil_in_nogil(self, node):
+ """Allow the (Python statement) node in nogil sections by wrapping it in a 'with gil' block."""
if self.nogil:
node = Nodes.GILStatNode(node.pos, state='gil', body=node)
return node
+ visit_RaiseStatNode = _inject_gil_in_nogil
+ visit_PrintStatNode = _inject_gil_in_nogil # the Py2 print statement; sadly, not the print() function
+
# further candidates:
- # def visit_AssertStatNode(self, node):
# def visit_ReraiseStatNode(self, node):
# nogil tracking
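The diff above replaces a dedicated visit_RaiseStatNode with a shared _inject_gil_in_nogil helper that is aliased to several visit_* names. Name-based visitor dispatch follows such aliases transparently, so each listed statement type gets the same wrapping when the surrounding code is nogil. A minimal Python sketch of the pattern (stand-in names, not the Cython transform itself):

class WrapInGilSketch(object):
    nogil = False  # flipped by the surrounding nogil-tracking visitors

    def _wrap_in_gil(self, node):
        # The real transform wraps `node` in Nodes.GILStatNode(state='gil', body=node).
        return ("with gil:", node) if self.nogil else node

    # One handler serves several node types.
    visit_RaiseStat = _wrap_in_gil
    visit_PrintStat = _wrap_in_gil

w = WrapInGilSketch()
w.nogil = True
assert w.visit_RaiseStat("raise ValueError()") == ("with gil:", "raise ValueError()")
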
@@ -2905,6 +3471,7 @@ class GilCheck(VisitorTransform):
self.env_stack.append(node.local_scope)
inner_nogil = node.local_scope.nogil
+ nogil_declarator_only = self.nogil_declarator_only
if inner_nogil:
self.nogil_declarator_only = True
@@ -2913,13 +3480,20 @@ class GilCheck(VisitorTransform):
self._visit_scoped_children(node, inner_nogil)
- # This cannot be nested, so it doesn't need backup/restore
- self.nogil_declarator_only = False
+ # FuncDefNodes can be nested, because a cpdef function contains a def function
+ # inside it. Therefore, restore the previous state.
+ self.nogil_declarator_only = nogil_declarator_only
self.env_stack.pop()
return node
def visit_GILStatNode(self, node):
+ if node.condition is not None:
+ error(node.condition.pos,
+ "Non-constant condition in a "
+ "`with %s(<condition>)` statement" % node.state)
+ return node
+
if self.nogil and node.nogil_check:
node.nogil_check()
@@ -2933,6 +3507,8 @@ class GilCheck(VisitorTransform):
else:
error(node.pos, "Trying to release the GIL while it was "
"previously released.")
+ if self.nogil_declarator_only:
+ node.scope_gil_state_known = False
if isinstance(node.finally_clause, Nodes.StatListNode):
# The finally clause of the GILStatNode is a GILExitNode,
@@ -2983,6 +3559,12 @@ class GilCheck(VisitorTransform):
self.visitchildren(node)
return node
+ def visit_GILExitNode(self, node):
+ if self.nogil_declarator_only:
+ node.scope_gil_state_known = False
+ self.visitchildren(node)
+ return node
+
def visit_Node(self, node):
if self.env_stack and self.nogil and node.nogil_check:
node.nogil_check(self.env_stack[-1])
@@ -2995,6 +3577,32 @@ class GilCheck(VisitorTransform):
return node
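The nogil_declarator_only change above is the usual backup/restore pattern for visitor state: since a cpdef function nests a def function inside it, the flag must be restored to its previous value on exit rather than unconditionally reset. A small Python sketch of the pattern (stand-in names, not GilCheck itself):

class NestedFlagVisitorSketch(object):
    def __init__(self):
        self.nogil_declarator_only = False

    def visit_funcdef(self, func):
        saved = self.nogil_declarator_only      # back up the outer state
        if func.get("nogil"):
            self.nogil_declarator_only = True
        for inner in func.get("nested", ()):
            self.visit_funcdef(inner)
        self.nogil_declarator_only = saved      # restore instead of resetting to False

v = NestedFlagVisitorSketch()
v.visit_funcdef({"nogil": True, "nested": [{"nogil": False}]})
assert v.nogil_declarator_only is False
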
+class CoerceCppTemps(EnvTransform, SkipDeclarations):
+ """
+ For temporary expressions that are implemented using std::optional, the temps must be
+ assigned using `__pyx_t_x = value;` but accessed using `something = (*__pyx_t_x)`. This transform
+ inserts a coercion node to take care of this, and it runs absolutely last (once nothing else can be
+ inserted into the tree).
+
+ TODO: a possible alternative would be to split ExprNode.result() into ExprNode.rhs_rhs() and ExprNode.lhs_rhs()???
+ """
+ def visit_ModuleNode(self, node):
+ if self.current_env().cpp:
+ # skipping this makes it essentially free for C files
+ self.visitchildren(node)
+ return node
+
+ def visit_ExprNode(self, node):
+ self.visitchildren(node)
+ if (self.current_env().directives['cpp_locals'] and
+ node.is_temp and node.type.is_cpp_class and
+ # Fake references are not replaced with "std::optional()".
+ not node.type.is_fake_reference):
+ node = ExprNodes.CppOptionalTempCoercion(node)
+
+ return node
+
+
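To make the assign-vs-access asymmetry in the docstring concrete: with cpp_locals, a C++-class temporary lives in an optional-like slot, so writes store into the slot directly while every read must dereference it, which is what the inserted CppOptionalTempCoercion expresses. A rough Python analogy of the generated C++ behaviour (an analogy only, not Cython output):

class OptionalSlotSketch(object):
    # Plays the role of a std::optional-backed temporary __pyx_t_x.
    def __init__(self):
        self._value = None
        self._has_value = False

    def assign(self, value):
        # Corresponds to the plain `__pyx_t_x = value;` store.
        self._value = value
        self._has_value = True

    def deref(self):
        # Corresponds to reading through `(*__pyx_t_x)`.
        if not self._has_value:
            raise RuntimeError("reading an empty optional temporary")
        return self._value

tmp = OptionalSlotSketch()
tmp.assign(42)
assert tmp.deref() == 42
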
class TransformBuiltinMethods(EnvTransform):
"""
Replace Cython's own cython.* builtins by the corresponding tree nodes.
@@ -3017,9 +3625,7 @@ class TransformBuiltinMethods(EnvTransform):
def visit_cython_attribute(self, node):
attribute = node.as_cython_attribute()
if attribute:
- if attribute == u'compiled':
- node = ExprNodes.BoolNode(node.pos, value=True)
- elif attribute == u'__version__':
+ if attribute == u'__version__':
from .. import __version__ as version
node = ExprNodes.StringNode(node.pos, value=EncodedString(version))
elif attribute == u'NULL':
@@ -3064,9 +3670,9 @@ class TransformBuiltinMethods(EnvTransform):
error(self.pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d"
% len(node.args))
if len(node.args) > 0:
- return node # nothing to do
+ return node # nothing to do
return ExprNodes.LocalsExprNode(pos, self.current_scope_node(), lenv)
- else: # dir()
+ else: # dir()
if len(node.args) > 1:
error(self.pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
% len(node.args))
@@ -3101,8 +3707,8 @@ class TransformBuiltinMethods(EnvTransform):
def _inject_eval(self, node, func_name):
lenv = self.current_env()
- entry = lenv.lookup_here(func_name)
- if entry or len(node.args) != 1:
+ entry = lenv.lookup(func_name)
+ if len(node.args) != 1 or (entry and not entry.is_builtin):
return node
# Inject globals and locals
node.args.append(ExprNodes.GlobalsExprNode(node.pos))
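At the Python level, the injection above amounts to turning a bare eval(expr) (or exec(expr)) into an explicit call that also receives the globals and locals mappings, so the evaluated code can see the surrounding namespaces. A plain-Python illustration of the resulting behaviour (not the transform itself):

def scaled(factor):
    expr = "factor * 2"
    # Passing the namespaces explicitly mirrors the arguments the transform
    # appends to the call node: a globals mapping followed by a locals mapping.
    return eval(expr, globals(), locals())

assert scaled(3) == 6
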
@@ -3119,8 +3725,7 @@ class TransformBuiltinMethods(EnvTransform):
return node
# Inject no-args super
def_node = self.current_scope_node()
- if (not isinstance(def_node, Nodes.DefNode) or not def_node.args or
- len(self.env_stack) < 2):
+ if not isinstance(def_node, Nodes.DefNode) or not def_node.args or len(self.env_stack) < 2:
return node
class_node, class_scope = self.env_stack[-2]
if class_scope.is_py_class_scope:
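The code above supports the zero-argument form of super() inside a def method by recovering the enclosing class and instance from the scope stack. In plain Python terms, the supported form is simply the familiar equivalence below (illustration only):

class Base(object):
    def describe(self):
        return "base"

class Child(Base):
    def describe(self):
        # super() with no arguments behaves like super(Child, self) here;
        # the transform reconstructs that pair from the enclosing scopes.
        return super().describe() + "/child"

assert Child().describe() == "base/child"
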
@@ -3259,10 +3864,17 @@ class ReplaceFusedTypeChecks(VisitorTransform):
self.visitchildren(node)
return self.transform(node)
+ def visit_GILStatNode(self, node):
+ """
+ Fold constant condition of GILStatNode.
+ """
+ self.visitchildren(node)
+ return self.transform(node)
+
def visit_PrimaryCmpNode(self, node):
with Errors.local_errors(ignore=True):
- type1 = node.operand1.analyse_as_type(self.local_scope)
- type2 = node.operand2.analyse_as_type(self.local_scope)
+ type1 = node.operand1.analyse_as_type(self.local_scope)
+ type2 = node.operand2.analyse_as_type(self.local_scope)
if type1 and type2:
false_node = ExprNodes.BoolNode(node.pos, value=False)
@@ -3398,9 +4010,14 @@ class DebugTransform(CythonTransform):
else:
pf_cname = node.py_func.entry.func_cname
+ # For functions defined using def, cname will be pyfunc_cname=__pyx_pf_*
+ # For functions defined using cpdef or cdef, cname will be func_cname=__pyx_f_*
+ # In all cases, cname will be the name of the function containing the actual code
+ cname = node.entry.pyfunc_cname or node.entry.func_cname
+
attrs = dict(
name=node.entry.name or getattr(node, 'name', '<unknown>'),
- cname=node.entry.func_cname,
+ cname=cname,
pf_cname=pf_cname,
qualified_name=node.local_scope.qualified_name,
lineno=str(node.pos[1]))
@@ -3428,10 +4045,10 @@ class DebugTransform(CythonTransform):
def visit_NameNode(self, node):
if (self.register_stepinto and
- node.type is not None and
- node.type.is_cfunction and
- getattr(node, 'is_called', False) and
- node.entry.func_cname is not None):
+ node.type is not None and
+ node.type.is_cfunction and
+ getattr(node, 'is_called', False) and
+ node.entry.func_cname is not None):
# don't check node.entry.in_cinclude, as 'cdef extern: ...'
# declared functions are not 'in_cinclude'.
# This means we will list called 'cdef' functions as
@@ -3450,26 +4067,16 @@ class DebugTransform(CythonTransform):
it's a "relevant frame" and it will know where to set the breakpoint
for 'break modulename'.
"""
- name = node.full_module_name.rpartition('.')[-1]
-
- cname_py2 = 'init' + name
- cname_py3 = 'PyInit_' + name
-
- py2_attrs = dict(
- name=name,
- cname=cname_py2,
+ self._serialize_modulenode_as_function(node, dict(
+ name=node.full_module_name.rpartition('.')[-1],
+ cname=node.module_init_func_cname(),
pf_cname='',
# Ignore the qualified_name, breakpoints should be set using
# `cy break modulename:lineno` for module-level breakpoints.
qualified_name='',
lineno='1',
is_initmodule_function="True",
- )
-
- py3_attrs = dict(py2_attrs, cname=cname_py3)
-
- self._serialize_modulenode_as_function(node, py2_attrs)
- self._serialize_modulenode_as_function(node, py3_attrs)
+ ))
def _serialize_modulenode_as_function(self, node, attrs):
self.tb.start('Function', attrs=attrs)