diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-02-20 18:15:34 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-02-20 18:15:34 +0000 |
commit | ff26cec7eba491cf4cf48542b21f44932baf9572 (patch) | |
tree | c563e12a5deaea6302d2f410bc97909dc689434b /numpy/lib/utils.py | |
parent | 348f725101ef97f538e4652c66b54d5633d4be4d (diff) | |
download | numpy-ff26cec7eba491cf4cf48542b21f44932baf9572.tar.gz |
3K: lib: adapt safe_eval for Py3 ast module
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 120 |
1 files changed, 88 insertions, 32 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 3e73c2a0f..1248c7d05 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -991,45 +991,98 @@ class SafeEval(object): """ - def visit(self, node, **kw): - cls = node.__class__ - meth = getattr(self,'visit'+cls.__name__,self.default) - return meth(node, **kw) + if sys.version_info[0] < 3: + def visit(self, node, **kw): + cls = node.__class__ + meth = getattr(self,'visit'+cls.__name__,self.default) + return meth(node, **kw) - def default(self, node, **kw): - raise SyntaxError("Unsupported source construct: %s" % node.__class__) + def default(self, node, **kw): + raise SyntaxError("Unsupported source construct: %s" + % node.__class__) - def visitExpression(self, node, **kw): - for child in node.getChildNodes(): - return self.visit(child, **kw) + def visitExpression(self, node, **kw): + for child in node.getChildNodes(): + return self.visit(child, **kw) - def visitConst(self, node, **kw): - return node.value + def visitConst(self, node, **kw): + return node.value - def visitDict(self, node,**kw): - return dict([(self.visit(k),self.visit(v)) for k,v in node.items]) + def visitDict(self, node,**kw): + return dict([(self.visit(k),self.visit(v)) for k,v in node.items]) - def visitTuple(self, node, **kw): - return tuple([self.visit(i) for i in node.nodes]) + def visitTuple(self, node, **kw): + return tuple([self.visit(i) for i in node.nodes]) - def visitList(self, node, **kw): - return [self.visit(i) for i in node.nodes] + def visitList(self, node, **kw): + return [self.visit(i) for i in node.nodes] - def visitUnaryAdd(self, node, **kw): - return +self.visit(node.getChildNodes()[0]) + def visitUnaryAdd(self, node, **kw): + return +self.visit(node.getChildNodes()[0]) - def visitUnarySub(self, node, **kw): - return -self.visit(node.getChildNodes()[0]) + def visitUnarySub(self, node, **kw): + return -self.visit(node.getChildNodes()[0]) - def visitName(self, node, **kw): - if node.name == 'False': - return False - elif node.name == 'True': - return True - elif node.name == 'None': - return None - else: - raise SyntaxError("Unknown name: %s" % node.name) + def visitName(self, node, **kw): + if node.name == 'False': + return False + elif node.name == 'True': + return True + elif node.name == 'None': + return None + else: + raise SyntaxError("Unknown name: %s" % node.name) + else: + + def visit(self, node): + cls = node.__class__ + meth = getattr(self, 'visit' + cls.__name__, self.default) + return meth(node) + + def default(self, node): + raise SyntaxError("Unsupported source construct: %s" + % node.__class__) + + def visitExpression(self, node): + return self.visit(node.body) + + def visitNum(self, node): + return node.n + + def visitStr(self, node): + return node.s + + def visitBytes(self, node): + return node.s + + def visitDict(self, node,**kw): + return dict([(self.visit(k), self.visit(v)) + for k, v in zip(node.keys, node.values)]) + + def visitTuple(self, node): + return tuple([self.visit(i) for i in node.elts]) + + def visitList(self, node): + return [self.visit(i) for i in node.elts] + + def visitUnaryOp(self, node): + import ast + if isinstance(node.op, ast.UAdd): + return +self.visit(node.operand) + elif isinstance(node.op, ast.USub): + return -self.visit(node.operand) + else: + raise SyntaxError("Unknown unary op: %r" % node.op) + + def visitName(self, node): + if node.id == 'False': + return False + elif node.id == 'True': + return True + elif node.id == 'None': + return None + else: + raise SyntaxError("Unknown name: %s" % node.id) def safe_eval(source): """ @@ -1075,10 +1128,13 @@ def safe_eval(source): """ # Local import to speed up numpy's import time. - import compiler + try: + import compiler + except ImportError: + import ast as compiler walker = SafeEval() try: - ast = compiler.parse(source, "eval") + ast = compiler.parse(source, mode="eval") except SyntaxError, err: raise try: |