Diffstat (limited to 'Cython/Compiler/ParseTreeTransforms.py')
-rw-r--r--  Cython/Compiler/ParseTreeTransforms.py | 995
1 file changed, 801 insertions(+), 194 deletions(-)
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 0e86d5b0e..8008cba9a 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,3 +1,5 @@ +# cython: language_level=3str + from __future__ import absolute_import import cython @@ -54,6 +56,12 @@ class SkipDeclarations(object): def visit_CStructOrUnionDefNode(self, node): return node + def visit_CppClassNode(self, node): + if node.visibility != "extern": + # Need to traverse methods. + self.visitchildren(node) + return node + class NormalizeTree(CythonTransform): """ @@ -81,6 +89,13 @@ class NormalizeTree(CythonTransform): self.is_in_statlist = False self.is_in_expr = False + def visit_ModuleNode(self, node): + self.visitchildren(node) + if not isinstance(node.body, Nodes.StatListNode): + # This can happen when the body only consists of a single (unused) declaration and no statements. + node.body = Nodes.StatListNode(pos=node.pos, stats=[node.body]) + return node + def visit_ExprNode(self, node): stacktmp = self.is_in_expr self.is_in_expr = True @@ -170,8 +185,9 @@ class PostParse(ScopeTrackingTransform): Note: Currently Parsing.py does a lot of interpretation and reorganization that can be refactored into this transform if a more pure Abstract Syntax Tree is wanted. - """ + - Some invalid uses of := assignment expressions are detected + """ def __init__(self, context): super(PostParse, self).__init__(context) self.specialattribute_handlers = { @@ -203,7 +219,9 @@ class PostParse(ScopeTrackingTransform): node.def_node = Nodes.DefNode( node.pos, name=node.name, doc=None, args=[], star_arg=None, starstar_arg=None, - body=node.loop, is_async_def=collector.has_await) + body=node.loop, is_async_def=collector.has_await, + is_generator_expression=True) + _AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass")) self.visitchildren(node) return node @@ -214,6 +232,7 @@ class PostParse(ScopeTrackingTransform): collector.visitchildren(node.loop) if collector.has_await: node.has_local_scope = True + _AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass")) self.visitchildren(node) return node @@ -246,7 +265,7 @@ class PostParse(ScopeTrackingTransform): if decl is not declbase: raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE) handler(decl) - continue # Remove declaration + continue # Remove declaration raise PostParseError(decl.pos, ERR_CDEF_INCLASS) first_assignment = self.scope_type != 'module' stats.append(Nodes.SingleAssignmentNode(node.pos, @@ -349,6 +368,141 @@ class PostParse(ScopeTrackingTransform): self.visitchildren(node) return node + def visit_AssertStatNode(self, node): + """Extract the exception raising into a RaiseStatNode to simplify GIL handling. 
+ """ + if node.exception is None: + node.exception = Nodes.RaiseStatNode( + node.pos, + exc_type=ExprNodes.NameNode(node.pos, name=EncodedString("AssertionError")), + exc_value=node.value, + exc_tb=None, + cause=None, + builtin_exc_name="AssertionError", + wrap_tuple_value=True, + ) + node.value = None + self.visitchildren(node) + return node + +class _AssignmentExpressionTargetNameFinder(TreeVisitor): + def __init__(self): + super(_AssignmentExpressionTargetNameFinder, self).__init__() + self.target_names = {} + + def find_target_names(self, target): + if target.is_name: + return [target.name] + elif target.is_sequence_constructor: + names = [] + for arg in target.args: + names.extend(self.find_target_names(arg)) + return names + # other targets are possible, but it isn't necessary to investigate them here + return [] + + def visit_ForInStatNode(self, node): + self.target_names[node] = tuple(self.find_target_names(node.target)) + self.visitchildren(node) + + def visit_ComprehensionNode(self, node): + pass # don't recurse into nested comprehensions + + def visit_LambdaNode(self, node): + pass # don't recurse into nested lambdas/generator expressions + + def visit_Node(self, node): + self.visitchildren(node) + + +class _AssignmentExpressionChecker(TreeVisitor): + """ + Enforces rules on AssignmentExpressions within generator expressions and comprehensions + """ + def __init__(self, loop_node, scope_is_class): + super(_AssignmentExpressionChecker, self).__init__() + + target_name_finder = _AssignmentExpressionTargetNameFinder() + target_name_finder.visit(loop_node) + self.target_names_dict = target_name_finder.target_names + self.in_iterator = False + self.in_nested_generator = False + self.scope_is_class = scope_is_class + self.current_target_names = () + self.all_target_names = set() + for names in self.target_names_dict.values(): + self.all_target_names.update(names) + + def _reset_state(self): + old_state = (self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names) + # note: not resetting self.in_iterator here, see visit_LambdaNode() below + self.in_nested_generator = False + self.scope_is_class = False + self.current_target_names = () + self.all_target_names = set() + return old_state + + def _set_state(self, old_state): + self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names = old_state + + @classmethod + def do_checks(cls, loop_node, scope_is_class): + checker = cls(loop_node, scope_is_class) + checker.visit(loop_node) + + def visit_ForInStatNode(self, node): + if self.in_nested_generator: + self.visitchildren(node) # once nested, don't do anything special + return + + current_target_names = self.current_target_names + target_name = self.target_names_dict.get(node, None) + if target_name: + self.current_target_names += target_name + + self.in_iterator = True + self.visit(node.iterator) + self.in_iterator = False + self.visitchildren(node, exclude=("iterator",)) + + self.current_target_names = current_target_names + + def visit_AssignmentExpressionNode(self, node): + if self.in_iterator: + error(node.pos, "assignment expression cannot be used in a comprehension iterable expression") + if self.scope_is_class: + error(node.pos, "assignment expression within a comprehension cannot be used in a class body") + if node.target_name in self.current_target_names: + error(node.pos, "assignment expression cannot rebind comprehension iteration variable '%s'" % + node.target_name) + 
elif node.target_name in self.all_target_names: + error(node.pos, "comprehension inner loop cannot rebind assignment expression target '%s'" % + node.target_name) + + def visit_LambdaNode(self, node): + # Don't reset "in_iterator" - an assignment expression in a lambda in an + # iterator is explicitly tested by the Python testcases and banned. + old_state = self._reset_state() + # the lambda node's "def_node" is not set up at this point, so we need to recurse into it explicitly. + self.visit(node.result_expr) + self._set_state(old_state) + + def visit_ComprehensionNode(self, node): + in_nested_generator = self.in_nested_generator + self.in_nested_generator = True + self.visitchildren(node) + self.in_nested_generator = in_nested_generator + + def visit_GeneratorExpressionNode(self, node): + in_nested_generator = self.in_nested_generator + self.in_nested_generator = True + # def_node isn't set up yet, so we need to visit the loop directly. + self.visit(node.loop) + self.in_nested_generator = in_nested_generator + + def visit_Node(self, node): + self.visitchildren(node) + def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): """Replace rhs items by LetRefNodes if they appear more than once. @@ -419,7 +573,7 @@ def sort_common_subsequences(items): return b.is_sequence_constructor and contains(b.args, a) for pos, item in enumerate(items): - key = item[1] # the ResultRefNode which has already been injected into the sequences + key = item[1] # the ResultRefNode which has already been injected into the sequences new_pos = pos for i in range(pos-1, -1, -1): if lower_than(key, items[i][0]): @@ -449,7 +603,7 @@ def flatten_parallel_assignments(input, output): # recursively, so that nested structures get matched as well. rhs = input[-1] if (not (rhs.is_sequence_constructor or isinstance(rhs, ExprNodes.UnicodeNode)) - or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])): + or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])): output.append(input) return @@ -533,7 +687,7 @@ def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args) targets.append(expr) # the starred target itself, must be assigned a (potentially empty) list - target = lhs_args[starred].target # unpack starred node + target = lhs_args[starred].target # unpack starred node starred_rhs = rhs_args[starred:] if lhs_remaining: starred_rhs = starred_rhs[:-lhs_remaining] @@ -579,19 +733,19 @@ class PxdPostParse(CythonTransform, SkipDeclarations): err = self.ERR_INLINE_ONLY if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass' - and node.name in ('__getbuffer__', '__releasebuffer__')): - err = None # allow these slots + and node.name in ('__getbuffer__', '__releasebuffer__')): + err = None # allow these slots if isinstance(node, Nodes.CFuncDefNode): if (u'inline' in node.modifiers and - self.scope_type in ('pxd', 'cclass')): + self.scope_type in ('pxd', 'cclass')): node.inline_in_pxd = True if node.visibility != 'private': err = self.ERR_NOGO_WITH_INLINE % node.visibility elif node.api: err = self.ERR_NOGO_WITH_INLINE % 'api' else: - err = None # allow inline function + err = None # allow inline function else: err = self.ERR_INLINE_ONLY @@ -630,6 +784,9 @@ class InterpretCompilerDirectives(CythonTransform): - Command-line arguments overriding these - @cython.directivename decorators - with cython.directivename: statements + - replaces "cython.compiled" with BoolNode(value=True) + allowing unreachable blocks to be removed at a fairly early stage + before cython typing rules 
are forced on applied This transform is responsible for interpreting these various sources and store the directive in two ways: @@ -668,17 +825,27 @@ class InterpretCompilerDirectives(CythonTransform): 'operator.comma' : ExprNodes.c_binop_constructor(','), } - special_methods = set(['declare', 'union', 'struct', 'typedef', - 'sizeof', 'cast', 'pointer', 'compiled', - 'NULL', 'fused_type', 'parallel']) + special_methods = { + 'declare', 'union', 'struct', 'typedef', + 'sizeof', 'cast', 'pointer', 'compiled', + 'NULL', 'fused_type', 'parallel', + } special_methods.update(unop_method_nodes) - valid_parallel_directives = set([ + valid_cython_submodules = { + 'cimports', + 'dataclasses', + 'operator', + 'parallel', + 'view', + } + + valid_parallel_directives = { "parallel", "prange", "threadid", #"threadsavailable", - ]) + } def __init__(self, context, compilation_directive_defaults): super(InterpretCompilerDirectives, self).__init__(context) @@ -701,6 +868,44 @@ class InterpretCompilerDirectives(CythonTransform): error(pos, "Invalid directive: '%s'." % (directive,)) return True + def _check_valid_cython_module(self, pos, module_name): + if not module_name.startswith("cython."): + return + submodule = module_name.split('.', 2)[1] + if submodule in self.valid_cython_submodules: + return + + extra = "" + # This is very rarely used, so don't waste space on static tuples. + hints = [ + line.split() for line in """\ + imp cimports + cimp cimports + para parallel + parra parallel + dataclass dataclasses + """.splitlines()[:-1] + ] + for wrong, correct in hints: + if module_name.startswith("cython." + wrong): + extra = "Did you mean 'cython.%s' ?" % correct + break + if not extra: + is_simple_cython_name = submodule in Options.directive_types + if not is_simple_cython_name and not submodule.startswith("_"): + # Try to find it in the Shadow module (i.e. the pure Python namespace of cython.*). + # FIXME: use an internal reference of "cython.*" names instead of Shadow.py + from .. import Shadow + is_simple_cython_name = hasattr(Shadow, submodule) + if is_simple_cython_name: + extra = "Instead, use 'import cython' and then 'cython.%s'." % submodule + + error(pos, "'%s' is not a valid cython.* module%s%s" % ( + module_name, + ". " if extra else "", + extra, + )) + # Set up processing and handle the cython: comments. def visit_ModuleNode(self, node): for key in sorted(node.directive_comments): @@ -717,6 +922,12 @@ class InterpretCompilerDirectives(CythonTransform): node.cython_module_names = self.cython_module_names return node + def visit_CompilerDirectivesNode(self, node): + old_directives, self.directives = self.directives, node.directives + self.visitchildren(node) + self.directives = old_directives + return node + # The following four functions track imports and cimports that # begin with "cython" def is_cython_directive(self, name): @@ -749,22 +960,36 @@ class InterpretCompilerDirectives(CythonTransform): return result def visit_CImportStatNode(self, node): - if node.module_name == u"cython": + module_name = node.module_name + if module_name == u"cython.cimports": + error(node.pos, "Cannot cimport the 'cython.cimports' package directly, only submodules.") + if module_name.startswith(u"cython.cimports."): + if node.as_name and node.as_name != u'cython': + node.module_name = module_name[len(u"cython.cimports."):] + return node + error(node.pos, + "Python cimports must use 'from cython.cimports... import ...'" + " or 'import ... 
as ...', not just 'import ...'") + + if module_name == u"cython": self.cython_module_names.add(node.as_name or u"cython") - elif node.module_name.startswith(u"cython."): - if node.module_name.startswith(u"cython.parallel."): + elif module_name.startswith(u"cython."): + if module_name.startswith(u"cython.parallel."): error(node.pos, node.module_name + " is not a module") - if node.module_name == u"cython.parallel": + else: + self._check_valid_cython_module(node.pos, module_name) + + if module_name == u"cython.parallel": if node.as_name and node.as_name != u"cython": - self.parallel_directives[node.as_name] = node.module_name + self.parallel_directives[node.as_name] = module_name else: self.cython_module_names.add(u"cython") self.parallel_directives[ - u"cython.parallel"] = node.module_name + u"cython.parallel"] = module_name self.module_scope.use_utility_code( UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c")) elif node.as_name: - self.directive_names[node.as_name] = node.module_name[7:] + self.directive_names[node.as_name] = module_name[7:] else: self.cython_module_names.add(u"cython") # if this cimport was a compiler directive, we don't @@ -773,26 +998,31 @@ class InterpretCompilerDirectives(CythonTransform): return node def visit_FromCImportStatNode(self, node): - if not node.relative_level and ( - node.module_name == u"cython" or node.module_name.startswith(u"cython.")): - submodule = (node.module_name + u".")[7:] + module_name = node.module_name + if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."): + # only supported for convenience + return self._create_cimport_from_import( + node.pos, module_name, node.relative_level, node.imported_names) + elif not node.relative_level and ( + module_name == u"cython" or module_name.startswith(u"cython.")): + self._check_valid_cython_module(node.pos, module_name) + submodule = (module_name + u".")[7:] newimp = [] - - for pos, name, as_name, kind in node.imported_names: + for pos, name, as_name in node.imported_names: full_name = submodule + name qualified_name = u"cython." + full_name - if self.is_parallel_directive(qualified_name, node.pos): # from cython cimport parallel, or # from cython.parallel cimport parallel, prange, ... 
self.parallel_directives[as_name or name] = qualified_name elif self.is_cython_directive(full_name): self.directive_names[as_name or name] = full_name - if kind is not None: - self.context.nonfatal_error(PostParseError(pos, - "Compiler directive imports must be plain imports")) + elif full_name in ['dataclasses', 'typing']: + self.directive_names[as_name or name] = full_name + # unlike many directives, still treat it as a regular module + newimp.append((pos, name, as_name)) else: - newimp.append((pos, name, as_name, kind)) + newimp.append((pos, name, as_name)) if not newimp: return None @@ -801,9 +1031,18 @@ class InterpretCompilerDirectives(CythonTransform): return node def visit_FromImportStatNode(self, node): - if (node.module.module_name.value == u"cython") or \ - node.module.module_name.value.startswith(u"cython."): - submodule = (node.module.module_name.value + u".")[7:] + import_node = node.module + module_name = import_node.module_name.value + if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."): + imported_names = [] + for name, name_node in node.items: + imported_names.append( + (name_node.pos, name, None if name == name_node.name else name_node.name)) + return self._create_cimport_from_import( + node.pos, module_name, import_node.level, imported_names) + elif module_name == u"cython" or module_name.startswith(u"cython."): + self._check_valid_cython_module(import_node.module_name.pos, module_name) + submodule = (module_name + u".")[7:] newimp = [] for name, name_node in node.items: full_name = submodule + name @@ -819,20 +1058,34 @@ class InterpretCompilerDirectives(CythonTransform): node.items = newimp return node + def _create_cimport_from_import(self, node_pos, module_name, level, imported_names): + if module_name == u"cython.cimports" or module_name.startswith(u"cython.cimports."): + module_name = EncodedString(module_name[len(u"cython.cimports."):]) # may be empty + + if module_name: + # from cython.cimports.a.b import x, y, z => from a.b cimport x, y, z + return Nodes.FromCImportStatNode( + node_pos, module_name=module_name, + relative_level=level, + imported_names=imported_names) + else: + # from cython.cimports import x, y, z => cimport x; cimport y; cimport z + return [ + Nodes.CImportStatNode( + pos, + module_name=dotted_name, + as_name=as_name, + is_absolute=level == 0) + for pos, dotted_name, as_name in imported_names + ] + def visit_SingleAssignmentNode(self, node): if isinstance(node.rhs, ExprNodes.ImportNode): module_name = node.rhs.module_name.value - is_parallel = (module_name + u".").startswith(u"cython.parallel.") - - if module_name != u"cython" and not is_parallel: + if module_name != u"cython" and not module_name.startswith("cython."): return node - module_name = node.rhs.module_name.value - as_name = node.lhs.name - - node = Nodes.CImportStatNode(node.pos, - module_name = module_name, - as_name = as_name) + node = Nodes.CImportStatNode(node.pos, module_name=module_name, as_name=node.lhs.name) node = self.visit_CImportStatNode(node) else: self.visitchildren(node) @@ -840,16 +1093,35 @@ class InterpretCompilerDirectives(CythonTransform): return node def visit_NameNode(self, node): + if node.annotation: + self.visitchild(node, 'annotation') if node.name in self.cython_module_names: node.is_cython_module = True else: directive = self.directive_names.get(node.name) if directive is not None: node.cython_attribute = directive + if node.as_cython_attribute() == "compiled": + return ExprNodes.BoolNode(node.pos, value=True) # replace 
early so unused branches can be dropped + # before they have a chance to cause compile-errors + return node + + def visit_AttributeNode(self, node): + self.visitchildren(node) + if node.as_cython_attribute() == "compiled": + return ExprNodes.BoolNode(node.pos, value=True) # replace early so unused branches can be dropped + # before they have a chance to cause compile-errors + return node + + def visit_AnnotationNode(self, node): + # for most transforms annotations are left unvisited (because they're unevaluated) + # however, it is important to pick up compiler directives from them + if node.expr: + self.visit(node.expr) return node def visit_NewExprNode(self, node): - self.visit(node.cppclass) + self.visitchild(node, 'cppclass') self.visitchildren(node) return node @@ -858,7 +1130,7 @@ class InterpretCompilerDirectives(CythonTransform): # decorator), returns a list of (directivename, value) pairs. # Otherwise, returns None if isinstance(node, ExprNodes.CallNode): - self.visit(node.function) + self.visitchild(node, 'function') optname = node.function.as_cython_attribute() if optname: directivetype = Options.directive_types.get(optname) @@ -890,7 +1162,7 @@ class InterpretCompilerDirectives(CythonTransform): if directivetype is bool: arg = ExprNodes.BoolNode(node.pos, value=True) return [self.try_to_parse_directive(optname, [arg], None, node.pos)] - elif directivetype is None: + elif directivetype is None or directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS: return [(optname, None)] else: raise PostParseError( @@ -945,7 +1217,7 @@ class InterpretCompilerDirectives(CythonTransform): if len(args) != 0: raise PostParseError(pos, 'The %s directive takes no prepositional arguments' % optname) - return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs]) + return optname, kwds.as_python_dict() elif directivetype is list: if kwds and len(kwds.key_value_pairs) != 0: raise PostParseError(pos, @@ -957,21 +1229,42 @@ class InterpretCompilerDirectives(CythonTransform): raise PostParseError(pos, 'The %s directive takes one compile-time string argument' % optname) return (optname, directivetype(optname, str(args[0].value))) + elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS: + # signal to pass things on without processing + return (optname, (args, kwds.as_python_dict() if kwds else {})) else: assert False - def visit_with_directives(self, node, directives): + def visit_with_directives(self, node, directives, contents_directives): + # contents_directives may be None if not directives: + assert not contents_directives return self.visit_Node(node) old_directives = self.directives - new_directives = dict(old_directives) - new_directives.update(directives) + new_directives = Options.copy_inherited_directives(old_directives, **directives) + if contents_directives is not None: + new_contents_directives = Options.copy_inherited_directives( + old_directives, **contents_directives) + else: + new_contents_directives = new_directives if new_directives == old_directives: return self.visit_Node(node) self.directives = new_directives + if (contents_directives is not None and + new_contents_directives != new_directives): + # we need to wrap the node body in a compiler directives node + node.body = Nodes.StatListNode( + node.body.pos, + stats=[ + Nodes.CompilerDirectivesNode( + node.body.pos, + directives=new_contents_directives, + body=node.body) + ] + ) retbody = self.visit_Node(node) self.directives = old_directives @@ -980,13 +1273,14 @@ class 
InterpretCompilerDirectives(CythonTransform): return Nodes.CompilerDirectivesNode( pos=retbody.pos, body=retbody, directives=new_directives) + # Handle decorators def visit_FuncDefNode(self, node): - directives = self._extract_directives(node, 'function') - return self.visit_with_directives(node, directives) + directives, contents_directives = self._extract_directives(node, 'function') + return self.visit_with_directives(node, directives, contents_directives) def visit_CVarDefNode(self, node): - directives = self._extract_directives(node, 'function') + directives, _ = self._extract_directives(node, 'function') for name, value in directives.items(): if name == 'locals': node.directive_locals = value @@ -995,27 +1289,34 @@ class InterpretCompilerDirectives(CythonTransform): node.pos, "Cdef functions can only take cython.locals(), " "staticmethod, or final decorators, got %s." % name)) - return self.visit_with_directives(node, directives) + return self.visit_with_directives(node, directives, contents_directives=None) def visit_CClassDefNode(self, node): - directives = self._extract_directives(node, 'cclass') - return self.visit_with_directives(node, directives) + directives, contents_directives = self._extract_directives(node, 'cclass') + return self.visit_with_directives(node, directives, contents_directives) def visit_CppClassNode(self, node): - directives = self._extract_directives(node, 'cppclass') - return self.visit_with_directives(node, directives) + directives, contents_directives = self._extract_directives(node, 'cppclass') + return self.visit_with_directives(node, directives, contents_directives) def visit_PyClassDefNode(self, node): - directives = self._extract_directives(node, 'class') - return self.visit_with_directives(node, directives) + directives, contents_directives = self._extract_directives(node, 'class') + return self.visit_with_directives(node, directives, contents_directives) def _extract_directives(self, node, scope_name): + """ + Returns two dicts - directives applied to this function/class + and directives applied to its contents. They aren't always the + same (since e.g. cfunc should not be applied to inner functions) + """ if not node.decorators: - return {} + return {}, {} # Split the decorators into two lists -- real decorators and directives directives = [] realdecs = [] both = [] + current_opt_dict = dict(self.directives) + missing = object() # Decorators coming first take precedence. for dec in node.decorators[::-1]: new_directives = self.try_to_parse_directives(dec.decorator) @@ -1023,8 +1324,14 @@ class InterpretCompilerDirectives(CythonTransform): for directive in new_directives: if self.check_directive_scope(node.pos, directive[0], scope_name): name, value = directive - if self.directives.get(name, object()) != value: + if current_opt_dict.get(name, missing) != value: + if name == 'cfunc' and 'ufunc' in current_opt_dict: + error(dec.pos, "Cannot apply @cfunc to @ufunc, please reverse the decorators.") directives.append(directive) + current_opt_dict[name] = value + else: + warning(dec.pos, "Directive does not change previous value (%s%s)" % ( + name, '=%r' % value if value is not None else '')) if directive[0] == 'staticmethod': both.append(dec) # Adapt scope type based on decorators that change it. 
@@ -1033,13 +1340,21 @@ class InterpretCompilerDirectives(CythonTransform): else: realdecs.append(dec) if realdecs and (scope_name == 'cclass' or - isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode, Nodes.CVarDefNode))): + isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))): + for realdec in realdecs: + dec_pos = realdec.pos + realdec = realdec.decorator + if ((realdec.is_name and realdec.name == "dataclass") or + (realdec.is_attribute and realdec.attribute == "dataclass")): + error(dec_pos, + "Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass") + # Note - arbitrary C function decorators are caught later in DecoratorTransform raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.") node.decorators = realdecs[::-1] + both[::-1] # merge or override repeated directives optdict = {} - for directive in directives: - name, value = directive + contents_optdict = {} + for name, value in directives: if name in optdict: old_value = optdict[name] # keywords and arg lists can be merged, everything @@ -1052,7 +1367,9 @@ class InterpretCompilerDirectives(CythonTransform): optdict[name] = value else: optdict[name] = value - return optdict + if name not in Options.immediate_decorator_directives: + contents_optdict[name] = value + return optdict, contents_optdict # Handle with-statements def visit_WithStatNode(self, node): @@ -1071,7 +1388,7 @@ class InterpretCompilerDirectives(CythonTransform): if self.check_directive_scope(node.pos, name, 'with statement'): directive_dict[name] = value if directive_dict: - return self.visit_with_directives(node.body, directive_dict) + return self.visit_with_directives(node.body, directive_dict, contents_directives=None) return self.visit_Node(node) @@ -1161,7 +1478,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): return node def visit_CallNode(self, node): - self.visit(node.function) + self.visitchild(node, 'function') if not self.parallel_directive: self.visitchildren(node, exclude=('function',)) return node @@ -1194,7 +1511,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): "Nested parallel with blocks are disallowed") self.state = 'parallel with' - body = self.visit(node.body) + body = self.visitchild(node, 'body') self.state = None newnode.body = body @@ -1210,13 +1527,13 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): error(node.pos, "The parallel directive must be called") return None - node.body = self.visit(node.body) + self.visitchild(node, 'body') return node def visit_ForInStatNode(self, node): "Rewrite 'for i in cython.parallel.prange(...):'" - self.visit(node.iterator) - self.visit(node.target) + self.visitchild(node, 'iterator') + self.visitchild(node, 'target') in_prange = isinstance(node.iterator.sequence, Nodes.ParallelRangeNode) @@ -1239,9 +1556,9 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): self.state = 'prange' - self.visit(node.body) + self.visitchild(node, 'body') self.state = previous_state - self.visit(node.else_clause) + self.visitchild(node, 'else_clause') return node def visit(self, node): @@ -1250,12 +1567,13 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): return super(ParallelRangeTransform, self).visit(node) -class WithTransform(CythonTransform, SkipDeclarations): +class WithTransform(VisitorTransform, SkipDeclarations): def visit_WithStatNode(self, node): self.visitchildren(node, 'body') pos = node.pos is_async = node.is_async body, target, manager = 
node.body, node.target, node.manager + manager = node.manager = ExprNodes.ProxyNode(manager) node.enter_call = ExprNodes.SimpleCallNode( pos, function=ExprNodes.AttributeNode( pos, obj=ExprNodes.CloneNode(manager), @@ -1316,6 +1634,130 @@ class WithTransform(CythonTransform, SkipDeclarations): # With statements are never inside expressions. return node + visit_Node = VisitorTransform.recurse_to_children + + +class _GeneratorExpressionArgumentsMarker(TreeVisitor, SkipDeclarations): + # called from "MarkClosureVisitor" + def __init__(self, gen_expr): + super(_GeneratorExpressionArgumentsMarker, self).__init__() + self.gen_expr = gen_expr + + def visit_ExprNode(self, node): + if not node.is_literal: + # Don't bother tagging literal nodes + assert (not node.generator_arg_tag) # nobody has tagged this first + node.generator_arg_tag = self.gen_expr + self.visitchildren(node) + + def visit_Node(self, node): + # We're only interested in the expressions that make up the iterator sequence, + # so don't go beyond ExprNodes (e.g. into ForFromStatNode). + return + + def visit_GeneratorExpressionNode(self, node): + node.generator_arg_tag = self.gen_expr + # don't visit children, can't handle overlapping tags + # (and assume generator expressions don't end up optimized out in a way + # that would require overlapping tags) + + +class _HandleGeneratorArguments(VisitorTransform, SkipDeclarations): + # used from within CreateClosureClasses + + def __call__(self, node): + from . import Visitor + assert isinstance(node, ExprNodes.GeneratorExpressionNode) + self.gen_node = node + + self.args = list(node.def_node.args) + self.call_parameters = list(node.call_parameters) + self.tag_count = 0 + self.substitutions = {} + + self.visitchildren(node) + + for k, v in self.substitutions.items(): + # doing another search for replacements here (at the end) allows us to sweep up + # CloneNodes too (which are often generated by the optimizer) + # (it could arguably be done more efficiently with a single traversal though) + Visitor.recursively_replace_node(node, k, v) + + node.def_node.args = self.args + node.call_parameters = self.call_parameters + return node + + def visit_GeneratorExpressionNode(self, node): + # a generator can also be substituted itself, so handle that case + new_node = self._handle_ExprNode(node, do_visit_children=False) + # However do not traverse into it. A new _HandleGeneratorArguments visitor will be used + # elsewhere to do that. + return node + + def _handle_ExprNode(self, node, do_visit_children): + if (node.generator_arg_tag is not None and self.gen_node is not None and + self.gen_node == node.generator_arg_tag): + pos = node.pos + # The reason for using ".x" as the name is that this is how CPython + # tracks internal variables in loops (e.g. + # { locals() for v in range(10) } + # will produce "v" and ".0"). We don't replicate this behaviour completely + # but use it as a starting point + name_source = self.tag_count + self.tag_count += 1 + name = EncodedString(".{0}".format(name_source)) + def_node = self.gen_node.def_node + if not def_node.local_scope.lookup_here(name): + from . import Symtab + cname = EncodedString(Naming.genexpr_arg_prefix + Symtab.punycodify_name(str(name_source))) + name_decl = Nodes.CNameDeclaratorNode(pos=pos, name=name) + type = node.type + if type.is_reference and not type.is_fake_reference: + # It isn't obvious whether the right thing to do would be to capture by reference or by + # value (C++ itself doesn't know either for lambda functions and forces a choice). 
+ # However, capture by reference involves converting to FakeReference which would require + # re-analysing AttributeNodes. Therefore I've picked capture-by-value out of convenience + # TODO - could probably be optimized by making the arg a reference but the closure not + # (see https://github.com/cython/cython/issues/2468) + type = type.ref_base_type + + name_decl.type = type + new_arg = Nodes.CArgDeclNode(pos=pos, declarator=name_decl, + base_type=None, default=None, annotation=None) + new_arg.name = name_decl.name + new_arg.type = type + + self.args.append(new_arg) + node.generator_arg_tag = None # avoid the possibility of this being caught again + self.call_parameters.append(node) + new_arg.entry = def_node.declare_argument(def_node.local_scope, new_arg) + new_arg.entry.cname = cname + new_arg.entry.in_closure = True + + if do_visit_children: + # now visit the Nodes's children (but remove self.gen_node to not to further + # argument substitution) + gen_node, self.gen_node = self.gen_node, None + self.visitchildren(node) + self.gen_node = gen_node + + # replace the node inside the generator with a looked-up name + # (initialized_check can safely be False because the source variable will be checked + # before it is captured if the check is required) + name_node = ExprNodes.NameNode(pos, name=name, initialized_check=False) + name_node.entry = self.gen_node.def_node.gbody.local_scope.lookup(name_node.name) + name_node.type = name_node.entry.type + self.substitutions[node] = name_node + return name_node + if do_visit_children: + self.visitchildren(node) + return node + + def visit_ExprNode(self, node): + return self._handle_ExprNode(node, True) + + visit_Node = VisitorTransform.recurse_to_children + class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): """ @@ -1328,16 +1770,16 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): _properties = None _map_property_attribute = { - 'getter': '__get__', - 'setter': '__set__', - 'deleter': '__del__', + 'getter': EncodedString('__get__'), + 'setter': EncodedString('__set__'), + 'deleter': EncodedString('__del__'), }.get def visit_CClassDefNode(self, node): if self._properties is None: self._properties = [] self._properties.append({}) - super(DecoratorTransform, self).visit_CClassDefNode(node) + node = super(DecoratorTransform, self).visit_CClassDefNode(node) self._properties.pop() return node @@ -1347,6 +1789,32 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): warning(node.pos, "'property %s:' syntax is deprecated, use '@property'" % node.name, level) return node + def visit_CFuncDefNode(self, node): + node = self.visit_FuncDefNode(node) + if not node.decorators: + return node + elif self.scope_type != 'cclass' or self.scope_node.visibility != "extern": + # at the moment cdef functions are very restricted in what decorators they can take + # so it's simple to test for the small number of allowed decorators.... 
+ if not (len(node.decorators) == 1 and node.decorators[0].decorator.is_name and + node.decorators[0].decorator.name == "staticmethod"): + error(node.decorators[0].pos, "Cdef functions cannot take arbitrary decorators.") + return node + + ret_node = node + decorator_node = self._find_property_decorator(node) + if decorator_node: + if decorator_node.decorator.is_name: + name = node.declared_name() + if name: + ret_node = self._add_property(node, name, decorator_node) + else: + error(decorator_node.pos, "C property decorator can only be @property") + + if node.decorators: + return self._reject_decorated_property(node, node.decorators[0]) + return ret_node + def visit_DefNode(self, node): scope_type = self.scope_type node = self.visit_FuncDefNode(node) @@ -1354,28 +1822,12 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): return node # transform @property decorators - properties = self._properties[-1] - for decorator_node in node.decorators[::-1]: + decorator_node = self._find_property_decorator(node) + if decorator_node is not None: decorator = decorator_node.decorator - if decorator.is_name and decorator.name == 'property': - if len(node.decorators) > 1: - return self._reject_decorated_property(node, decorator_node) - name = node.name - node.name = EncodedString('__get__') - node.decorators.remove(decorator_node) - stat_list = [node] - if name in properties: - prop = properties[name] - prop.pos = node.pos - prop.doc = node.doc - prop.body.stats = stat_list - return [] - prop = Nodes.PropertyNode(node.pos, name=name) - prop.doc = node.doc - prop.body = Nodes.StatListNode(node.pos, stats=stat_list) - properties[name] = prop - return [prop] - elif decorator.is_attribute and decorator.obj.name in properties: + if decorator.is_name: + return self._add_property(node, node.name, decorator_node) + else: handler_name = self._map_property_attribute(decorator.attribute) if handler_name: if decorator.obj.name != node.name: @@ -1386,7 +1838,7 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): elif len(node.decorators) > 1: return self._reject_decorated_property(node, decorator_node) else: - return self._add_to_property(properties, node, handler_name, decorator_node) + return self._add_to_property(node, handler_name, decorator_node) # we clear node.decorators, so we need to set the # is_staticmethod/is_classmethod attributes now @@ -1401,6 +1853,18 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): node.decorators = None return self.chain_decorators(node, decs, node.name) + def _find_property_decorator(self, node): + properties = self._properties[-1] + for decorator_node in node.decorators[::-1]: + decorator = decorator_node.decorator + if decorator.is_name and decorator.name == 'property': + # @property + return decorator_node + elif decorator.is_attribute and decorator.obj.name in properties: + # @prop.setter etc. 
+ return decorator_node + return None + @staticmethod def _reject_decorated_property(node, decorator_node): # restrict transformation to outermost decorator as wrapped properties will probably not work @@ -1409,9 +1873,42 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): error(deco.pos, "Property methods with additional decorators are not supported") return node - @staticmethod - def _add_to_property(properties, node, name, decorator): + def _add_property(self, node, name, decorator_node): + if len(node.decorators) > 1: + return self._reject_decorated_property(node, decorator_node) + node.decorators.remove(decorator_node) + properties = self._properties[-1] + is_cproperty = isinstance(node, Nodes.CFuncDefNode) + body = Nodes.StatListNode(node.pos, stats=[node]) + if is_cproperty: + if name in properties: + error(node.pos, "C property redeclared") + if 'inline' not in node.modifiers: + error(node.pos, "C property method must be declared 'inline'") + prop = Nodes.CPropertyNode(node.pos, doc=node.doc, name=name, body=body) + elif name in properties: + prop = properties[name] + if prop.is_cproperty: + error(node.pos, "C property redeclared") + else: + node.name = EncodedString("__get__") + prop.pos = node.pos + prop.doc = node.doc + prop.body.stats = [node] + return None + else: + node.name = EncodedString("__get__") + prop = Nodes.PropertyNode( + node.pos, name=name, doc=node.doc, body=body) + properties[name] = prop + return prop + + def _add_to_property(self, node, name, decorator): + properties = self._properties[-1] prop = properties[node.name] + if prop.is_cproperty: + error(node.pos, "C property redeclared") + return None node.name = name node.decorators.remove(decorator) stats = prop.body.stats @@ -1421,7 +1918,7 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): break else: stats.append(node) - return [] + return None @staticmethod def chain_decorators(node, decorators, name): @@ -1499,6 +1996,10 @@ class CnameDirectivesTransform(CythonTransform, SkipDeclarations): class ForwardDeclareTypes(CythonTransform): + """ + Declare all global cdef names that we allow referencing in other places, + before declaring everything (else) in source code order. 
+ """ def visit_CompilerDirectivesNode(self, node): env = self.module_scope @@ -1542,6 +2043,14 @@ class ForwardDeclareTypes(CythonTransform): entry.type.get_all_specialized_function_types() return node + def visit_FuncDefNode(self, node): + # no traversal needed + return node + + def visit_PyClassDefNode(self, node): + # no traversal needed + return node + class AnalyseDeclarationsTransform(EnvTransform): @@ -1622,6 +2131,9 @@ if VALUE is not None: def visit_CClassDefNode(self, node): node = self.visit_ClassDefNode(node) + if node.scope and 'dataclasses.dataclass' in node.scope.directives: + from .Dataclass import handle_cclass_dataclass + handle_cclass_dataclass(node, node.scope.directives['dataclasses.dataclass'], self) if node.scope and node.scope.implemented and node.body: stats = [] for entry in node.scope.var_entries: @@ -1633,8 +2145,8 @@ if VALUE is not None: if stats: node.body.stats += stats if (node.visibility != 'extern' - and not node.scope.lookup('__reduce__') - and not node.scope.lookup('__reduce_ex__')): + and not node.scope.lookup('__reduce__') + and not node.scope.lookup('__reduce_ex__')): self._inject_pickle_methods(node) return node @@ -1688,9 +2200,9 @@ if VALUE is not None: pickle_func = TreeFragment(u""" def __reduce_cython__(self): - raise TypeError("%(msg)s") + raise TypeError, "%(msg)s" def __setstate_cython__(self, __pyx_state): - raise TypeError("%(msg)s") + raise TypeError, "%(msg)s" """ % {'msg': msg}, level='c_class', pipeline=[NormalizeTree(None)]).substitute({}) pickle_func.analyse_declarations(node.scope) @@ -1702,10 +2214,11 @@ if VALUE is not None: if not e.type.is_pyobject: e.type.create_to_py_utility_code(env) e.type.create_from_py_utility_code(env) + all_members_names = [e.name for e in all_members] checksums = _calculate_pickle_checksums(all_members_names) - unpickle_func_name = '__pyx_unpickle_%s' % node.class_name + unpickle_func_name = '__pyx_unpickle_%s' % node.punycode_class_name # TODO(robertwb): Move the state into the third argument # so it can be pickled *after* self is memoized. 
@@ -1715,7 +2228,7 @@ if VALUE is not None: cdef object __pyx_result if __pyx_checksum not in %(checksums)s: from pickle import PickleError as __pyx_PickleError - raise __pyx_PickleError("Incompatible checksums (0x%%x vs %(checksums)s = (%(members)s))" %% __pyx_checksum) + raise __pyx_PickleError, "Incompatible checksums (0x%%x vs %(checksums)s = (%(members)s))" %% __pyx_checksum __pyx_result = %(class_name)s.__new__(__pyx_type) if __pyx_state is not None: %(unpickle_func_name)s__set_state(<%(class_name)s> __pyx_result, __pyx_state) @@ -1783,8 +2296,8 @@ if VALUE is not None: for decorator in old_decorators: func = decorator.decorator if (not func.is_name or - func.name not in ('staticmethod', 'classmethod') or - env.lookup_here(func.name)): + func.name not in ('staticmethod', 'classmethod') or + env.lookup_here(func.name)): # not a static or classmethod decorators.append(decorator) @@ -1802,8 +2315,10 @@ if VALUE is not None: "Handle def or cpdef fused functions" # Create PyCFunction nodes for each specialization node.stats.insert(0, node.py_func) - node.py_func = self.visit(node.py_func) + self.visitchild(node, 'py_func') node.update_fused_defnode_entry(env) + # For the moment, fused functions do not support METH_FASTCALL + node.py_func.entry.signature.use_fastcall = False pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, binding=True) pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env)) node.resulting_fused_function = pycfunc @@ -1846,19 +2361,6 @@ if VALUE is not None: return node - def _handle_nogil_cleanup(self, lenv, node): - "Handle cleanup for 'with gil' blocks in nogil functions." - if lenv.nogil and lenv.has_with_gil_block: - # Acquire the GIL for cleanup in 'nogil' functions, by wrapping - # the entire function body in try/finally. - # The corresponding release will be taken care of by - # Nodes.FuncDefNode.generate_function_definitions() - node.body = Nodes.NogilTryFinallyStatNode( - node.body.pos, - body=node.body, - finally_clause=Nodes.EnsureGILNode(node.body.pos), - finally_except_clause=Nodes.EnsureGILNode(node.body.pos)) - def _handle_fused(self, node): if node.is_generator and node.has_fused_arguments: node.has_fused_arguments = False @@ -1890,6 +2392,8 @@ if VALUE is not None: for var, type_node in node.directive_locals.items(): if not lenv.lookup_here(var): # don't redeclare args type = type_node.analyse_as_type(lenv) + if type and type.is_fused and lenv.fused_to_specific: + type = type.specialize(lenv.fused_to_specific) if type: lenv.declare_var(var, type, type_node.pos) else: @@ -1899,17 +2403,18 @@ if VALUE is not None: node = self._create_fused_function(env, node) else: node.body.analyse_declarations(lenv) - self._handle_nogil_cleanup(lenv, node) self._super_visit_FuncDefNode(node) self.seen_vars_stack.pop() + + if "ufunc" in lenv.directives: + from . 
import UFuncs + return UFuncs.convert_to_ufunc(node) return node def visit_DefNode(self, node): node = self.visit_FuncDefNode(node) env = self.current_env() - if isinstance(node, Nodes.DefNode) and node.is_wrapper: - env = env.parent_scope if (not isinstance(node, Nodes.DefNode) or node.fused_py_func or node.is_generator_body or not node.needs_assignment_synthesis(env)): @@ -1959,11 +2464,17 @@ if VALUE is not None: assmt.analyse_declarations(env) return assmt + def visit_func_outer_attrs(self, node): + # any names in the outer attrs should not be looked up in the function "seen_vars_stack" + stack = self.seen_vars_stack.pop() + super(AnalyseDeclarationsTransform, self).visit_func_outer_attrs(node) + self.seen_vars_stack.append(stack) + def visit_ScopedExprNode(self, node): env = self.current_env() node.analyse_declarations(env) # the node may or may not have a local scope - if node.has_local_scope: + if node.expr_scope: self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) self.enter_scope(node, node.expr_scope) node.analyse_scoped_declarations(node.expr_scope) @@ -1971,6 +2482,7 @@ if VALUE is not None: self.exit_scope() self.seen_vars_stack.pop() else: + node.analyse_scoped_declarations(env) self.visitchildren(node) return node @@ -1993,7 +2505,7 @@ if VALUE is not None: # (so it can't happen later). # Note that we don't return the original node, as it is # never used after this phase. - if True: # private (default) + if True: # private (default) return None self_value = ExprNodes.AttributeNode( @@ -2085,8 +2597,8 @@ if VALUE is not None: if node.name in self.seen_vars_stack[-1]: entry = self.current_env().lookup(node.name) if (entry is None or entry.visibility != 'extern' - and not entry.scope.is_c_class_scope): - warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2) + and not entry.scope.is_c_class_scope): + error(node.pos, "cdef variable '%s' declared after it is used" % node.name) self.visitchildren(node) return node @@ -2096,13 +2608,12 @@ if VALUE is not None: return None def visit_CnameDecoratorNode(self, node): - child_node = self.visit(node.node) + child_node = self.visitchild(node, 'node') if not child_node: return None - if type(child_node) is list: # Assignment synthesized - node.child_node = child_node[0] + if type(child_node) is list: # Assignment synthesized + node.node = child_node[0] return [node] + child_node[1:] - node.node = child_node return node def create_Property(self, entry): @@ -2122,6 +2633,11 @@ if VALUE is not None: property.doc = entry.doc return property + def visit_AssignmentExpressionNode(self, node): + self.visitchildren(node) + node.analyse_declarations(self.current_env()) + return node + def _calculate_pickle_checksums(member_names): # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons. 
@@ -2130,7 +2646,7 @@ def _calculate_pickle_checksums(member_names): member_names_string = ' '.join(member_names).encode('utf-8') hash_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {} checksums = [] - for algo_name in ['md5', 'sha256', 'sha1']: + for algo_name in ['sha256', 'sha1', 'md5']: try: mkchecksum = getattr(hashlib, algo_name) checksum = mkchecksum(member_names_string, **hash_kwargs).hexdigest() @@ -2219,7 +2735,7 @@ class CalculateQualifiedNamesTransform(EnvTransform): def visit_ClassDefNode(self, node): orig_qualified_name = self.qualified_name[:] entry = (getattr(node, 'entry', None) or # PyClass - self.current_env().lookup_here(node.name)) # CClass + self.current_env().lookup_here(node.target.name)) # CClass self._append_entry(entry) self._super_visit_ClassDefNode(node) self.qualified_name = orig_qualified_name @@ -2328,8 +2844,8 @@ class ExpandInplaceOperators(EnvTransform): operand2 = rhs, inplace=True) # Manually analyse types for new node. - lhs.analyse_target_types(env) - dup.analyse_types(env) + lhs = lhs.analyse_target_types(env) + dup.analyse_types(env) # FIXME: no need to reanalyse the copy, right? binop.analyse_operation(env) node = Nodes.SingleAssignmentNode( node.pos, @@ -2379,13 +2895,15 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations): return_type_node = self.directives.get('returns') if return_type_node is None and self.directives['annotation_typing']: return_type_node = node.return_type_annotation - # for Python anntations, prefer safe exception handling by default + # for Python annotations, prefer safe exception handling by default if return_type_node is not None and except_val is None: except_val = (None, True) # except * elif except_val is None: - # backward compatible default: no exception check - except_val = (None, False) + # backward compatible default: no exception check, unless there's also a "@returns" declaration + except_val = (None, True if return_type_node else False) if 'ccall' in self.directives: + if 'cfunc' in self.directives: + error(node.pos, "cfunc and ccall directives cannot be combined") node = node.as_cfunction( overridable=True, modifiers=modifiers, nogil=nogil, returns=return_type_node, except_val=except_val) @@ -2437,8 +2955,6 @@ class AlignFunctionDefinitions(CythonTransform): def visit_ModuleNode(self, node): self.scope = node.scope - self.directives = node.directives - self.imported_names = set() # hack, see visit_FromImportStatNode() self.visitchildren(node) return node @@ -2476,15 +2992,45 @@ class AlignFunctionDefinitions(CythonTransform): error(pxd_def.pos, "previous declaration here") return None node = node.as_cfunction(pxd_def) - elif (self.scope.is_module_scope and self.directives['auto_cpdef'] - and not node.name in self.imported_names - and node.is_cdef_func_compatible()): - # FIXME: cpdef-ing should be done in analyse_declarations() - node = node.as_cfunction(scope=self.scope) # Enable this when nested cdef functions are allowed. 
# self.visitchildren(node) return node + def visit_ExprNode(self, node): + # ignore lambdas and everything else that appears in expressions + return node + + +class AutoCpdefFunctionDefinitions(CythonTransform): + + def visit_ModuleNode(self, node): + self.directives = node.directives + self.imported_names = set() # hack, see visit_FromImportStatNode() + self.scope = node.scope + self.visitchildren(node) + return node + + def visit_DefNode(self, node): + if (self.scope.is_module_scope and self.directives['auto_cpdef'] + and node.name not in self.imported_names + and node.is_cdef_func_compatible()): + # FIXME: cpdef-ing should be done in analyse_declarations() + node = node.as_cfunction(scope=self.scope) + return node + + def visit_CClassDefNode(self, node, pxd_def=None): + if pxd_def is None: + pxd_def = self.scope.lookup(node.class_name) + if pxd_def: + if not pxd_def.defined_in_pxd: + return node + outer_scope = self.scope + self.scope = pxd_def.type.scope + self.visitchildren(node) + if pxd_def: + self.scope = outer_scope + return node + def visit_FromImportStatNode(self, node): # hack to prevent conditional import fallback functions from # being cdpef-ed (global Python variables currently conflict @@ -2504,8 +3050,7 @@ class RemoveUnreachableCode(CythonTransform): if not self.current_directives['remove_unreachable']: return node self.visitchildren(node) - for idx, stat in enumerate(node.stats): - idx += 1 + for idx, stat in enumerate(node.stats, 1): if stat.is_terminator: if idx < len(node.stats): if self.current_directives['warn.unreachable']: @@ -2604,6 +3149,8 @@ class YieldNodeCollector(TreeVisitor): class MarkClosureVisitor(CythonTransform): + # In addition to marking closures this is also responsible to finding parts of the + # generator iterable and marking them def visit_ModuleNode(self, node): self.needs_closure = False @@ -2649,7 +3196,8 @@ class MarkClosureVisitor(CythonTransform): star_arg=node.star_arg, starstar_arg=node.starstar_arg, doc=node.doc, decorators=node.decorators, gbody=gbody, lambda_name=node.lambda_name, - return_type_annotation=node.return_type_annotation) + return_type_annotation=node.return_type_annotation, + is_generator_expression=node.is_generator_expression) return coroutine def visit_CFuncDefNode(self, node): @@ -2673,6 +3221,19 @@ class MarkClosureVisitor(CythonTransform): self.needs_closure = True return node + def visit_GeneratorExpressionNode(self, node): + node = self.visit_LambdaNode(node) + if not isinstance(node.loop, Nodes._ForInStatNode): + # Possibly should handle ForFromStatNode + # but for now do nothing + return node + itseq = node.loop.iterator.sequence + # literals do not need replacing with an argument + if itseq.is_literal: + return node + _GeneratorExpressionArgumentsMarker(node).visit(itseq) + return node + class CreateClosureClasses(CythonTransform): # Output closure classes in module scope for all functions @@ -2726,7 +3287,7 @@ class CreateClosureClasses(CythonTransform): if not node.py_cfunc_node: raise InternalError("DefNode does not have assignment node") inner_node = node.py_cfunc_node - inner_node.needs_self_code = False + inner_node.needs_closure_code = False node.needs_outer_scope = False if node.is_generator: @@ -2745,6 +3306,7 @@ class CreateClosureClasses(CythonTransform): as_name = '%s_%s' % ( target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname.replace('.','__')) + as_name = EncodedString(as_name) entry = target_module_scope.declare_c_class( name=as_name, pos=node.pos, defining=True, @@ -2816,6 
+3378,10 @@ class CreateClosureClasses(CythonTransform): self.visitchildren(node) return node + def visit_GeneratorExpressionNode(self, node): + node = _HandleGeneratorArguments()(node) + return self.visit_LambdaNode(node) + class InjectGilHandling(VisitorTransform, SkipDeclarations): """ @@ -2824,20 +3390,20 @@ class InjectGilHandling(VisitorTransform, SkipDeclarations): Must run before the AnalyseDeclarationsTransform to make sure the GILStatNodes get set up, parallel sections know that the GIL is acquired inside of them, etc. """ - def __call__(self, root): - self.nogil = False - return super(InjectGilHandling, self).__call__(root) + nogil = False # special node handling - def visit_RaiseStatNode(self, node): - """Allow raising exceptions in nogil sections by wrapping them in a 'with gil' block.""" + def _inject_gil_in_nogil(self, node): + """Allow the (Python statement) node in nogil sections by wrapping it in a 'with gil' block.""" if self.nogil: node = Nodes.GILStatNode(node.pos, state='gil', body=node) return node + visit_RaiseStatNode = _inject_gil_in_nogil + visit_PrintStatNode = _inject_gil_in_nogil # sadly, not the function + # further candidates: - # def visit_AssertStatNode(self, node): # def visit_ReraiseStatNode(self, node): # nogil tracking @@ -2905,6 +3471,7 @@ class GilCheck(VisitorTransform): self.env_stack.append(node.local_scope) inner_nogil = node.local_scope.nogil + nogil_declarator_only = self.nogil_declarator_only if inner_nogil: self.nogil_declarator_only = True @@ -2913,13 +3480,20 @@ class GilCheck(VisitorTransform): self._visit_scoped_children(node, inner_nogil) - # This cannot be nested, so it doesn't need backup/restore - self.nogil_declarator_only = False + # FuncDefNodes can be nested, because a cpdef function contains a def function + # inside it. Therefore restore to previous state + self.nogil_declarator_only = nogil_declarator_only self.env_stack.pop() return node def visit_GILStatNode(self, node): + if node.condition is not None: + error(node.condition.pos, + "Non-constant condition in a " + "`with %s(<condition>)` statement" % node.state) + return node + if self.nogil and node.nogil_check: node.nogil_check() @@ -2933,6 +3507,8 @@ class GilCheck(VisitorTransform): else: error(node.pos, "Trying to release the GIL while it was " "previously released.") + if self.nogil_declarator_only: + node.scope_gil_state_known = False if isinstance(node.finally_clause, Nodes.StatListNode): # The finally clause of the GILStatNode is a GILExitNode, @@ -2983,6 +3559,12 @@ class GilCheck(VisitorTransform): self.visitchildren(node) return node + def visit_GILExitNode(self, node): + if self.nogil_declarator_only: + node.scope_gil_state_known = False + self.visitchildren(node) + return node + def visit_Node(self, node): if self.env_stack and self.nogil and node.nogil_check: node.nogil_check(self.env_stack[-1]) @@ -2995,6 +3577,32 @@ class GilCheck(VisitorTransform): return node +class CoerceCppTemps(EnvTransform, SkipDeclarations): + """ + For temporary expression that are implemented using std::optional it's necessary the temps are + assigned using `__pyx_t_x = value;` but accessed using `something = (*__pyx_t_x)`. This transform + inserts a coercion node to take care of this, and runs absolutely last (once nothing else can be + inserted into the tree) + + TODO: a possible alternative would be to split ExprNode.result() into ExprNode.rhs_rhs() and ExprNode.lhs_rhs()??? 
+ """ + def visit_ModuleNode(self, node): + if self.current_env().cpp: + # skipping this makes it essentially free for C files + self.visitchildren(node) + return node + + def visit_ExprNode(self, node): + self.visitchildren(node) + if (self.current_env().directives['cpp_locals'] and + node.is_temp and node.type.is_cpp_class and + # Fake references are not replaced with "std::optional()". + not node.type.is_fake_reference): + node = ExprNodes.CppOptionalTempCoercion(node) + + return node + + class TransformBuiltinMethods(EnvTransform): """ Replace Cython's own cython.* builtins by the corresponding tree nodes. @@ -3017,9 +3625,7 @@ class TransformBuiltinMethods(EnvTransform): def visit_cython_attribute(self, node): attribute = node.as_cython_attribute() if attribute: - if attribute == u'compiled': - node = ExprNodes.BoolNode(node.pos, value=True) - elif attribute == u'__version__': + if attribute == u'__version__': from .. import __version__ as version node = ExprNodes.StringNode(node.pos, value=EncodedString(version)) elif attribute == u'NULL': @@ -3064,9 +3670,9 @@ class TransformBuiltinMethods(EnvTransform): error(self.pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d" % len(node.args)) if len(node.args) > 0: - return node # nothing to do + return node # nothing to do return ExprNodes.LocalsExprNode(pos, self.current_scope_node(), lenv) - else: # dir() + else: # dir() if len(node.args) > 1: error(self.pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d" % len(node.args)) @@ -3101,8 +3707,8 @@ class TransformBuiltinMethods(EnvTransform): def _inject_eval(self, node, func_name): lenv = self.current_env() - entry = lenv.lookup_here(func_name) - if entry or len(node.args) != 1: + entry = lenv.lookup(func_name) + if len(node.args) != 1 or (entry and not entry.is_builtin): return node # Inject globals and locals node.args.append(ExprNodes.GlobalsExprNode(node.pos)) @@ -3119,8 +3725,7 @@ class TransformBuiltinMethods(EnvTransform): return node # Inject no-args super def_node = self.current_scope_node() - if (not isinstance(def_node, Nodes.DefNode) or not def_node.args or - len(self.env_stack) < 2): + if not isinstance(def_node, Nodes.DefNode) or not def_node.args or len(self.env_stack) < 2: return node class_node, class_scope = self.env_stack[-2] if class_scope.is_py_class_scope: @@ -3259,10 +3864,17 @@ class ReplaceFusedTypeChecks(VisitorTransform): self.visitchildren(node) return self.transform(node) + def visit_GILStatNode(self, node): + """ + Fold constant condition of GILStatNode. 
+ """ + self.visitchildren(node) + return self.transform(node) + def visit_PrimaryCmpNode(self, node): with Errors.local_errors(ignore=True): - type1 = node.operand1.analyse_as_type(self.local_scope) - type2 = node.operand2.analyse_as_type(self.local_scope) + type1 = node.operand1.analyse_as_type(self.local_scope) + type2 = node.operand2.analyse_as_type(self.local_scope) if type1 and type2: false_node = ExprNodes.BoolNode(node.pos, value=False) @@ -3398,9 +4010,14 @@ class DebugTransform(CythonTransform): else: pf_cname = node.py_func.entry.func_cname + # For functions defined using def, cname will be pyfunc_cname=__pyx_pf_* + # For functions defined using cpdef or cdef, cname will be func_cname=__pyx_f_* + # In all cases, cname will be the name of the function containing the actual code + cname = node.entry.pyfunc_cname or node.entry.func_cname + attrs = dict( name=node.entry.name or getattr(node, 'name', '<unknown>'), - cname=node.entry.func_cname, + cname=cname, pf_cname=pf_cname, qualified_name=node.local_scope.qualified_name, lineno=str(node.pos[1])) @@ -3428,10 +4045,10 @@ class DebugTransform(CythonTransform): def visit_NameNode(self, node): if (self.register_stepinto and - node.type is not None and - node.type.is_cfunction and - getattr(node, 'is_called', False) and - node.entry.func_cname is not None): + node.type is not None and + node.type.is_cfunction and + getattr(node, 'is_called', False) and + node.entry.func_cname is not None): # don't check node.entry.in_cinclude, as 'cdef extern: ...' # declared functions are not 'in_cinclude'. # This means we will list called 'cdef' functions as @@ -3450,26 +4067,16 @@ class DebugTransform(CythonTransform): it's a "relevant frame" and it will know where to set the breakpoint for 'break modulename'. """ - name = node.full_module_name.rpartition('.')[-1] - - cname_py2 = 'init' + name - cname_py3 = 'PyInit_' + name - - py2_attrs = dict( - name=name, - cname=cname_py2, + self._serialize_modulenode_as_function(node, dict( + name=node.full_module_name.rpartition('.')[-1], + cname=node.module_init_func_cname(), pf_cname='', # Ignore the qualified_name, breakpoints should be set using # `cy break modulename:lineno` for module-level breakpoints. qualified_name='', lineno='1', is_initmodule_function="True", - ) - - py3_attrs = dict(py2_attrs, cname=cname_py3) - - self._serialize_modulenode_as_function(node, py2_attrs) - self._serialize_modulenode_as_function(node, py3_attrs) + )) def _serialize_modulenode_as_function(self, node, attrs): self.tb.start('Function', attrs=attrs) |