diff options
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 144 |
1 files changed, 46 insertions, 98 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index df0052493..4a68dde18 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1010,103 +1010,59 @@ class SafeEval(object): """ - if sys.version_info[0] < 3: - def visit(self, node, **kw): - cls = node.__class__ - meth = getattr(self, 'visit'+cls.__name__, self.default) - return meth(node, **kw) - - def default(self, node, **kw): - raise SyntaxError("Unsupported source construct: %s" - % node.__class__) - - def visitExpression(self, node, **kw): - for child in node.getChildNodes(): - return self.visit(child, **kw) - - def visitConst(self, node, **kw): - return node.value - - def visitDict(self, node,**kw): - return dict( - [(self.visit(k), self.visit(v)) for k, v in node.items] - ) - - def visitTuple(self, node, **kw): - return tuple([self.visit(i) for i in node.nodes]) - - def visitList(self, node, **kw): - return [self.visit(i) for i in node.nodes] - - def visitUnaryAdd(self, node, **kw): - return +self.visit(node.getChildNodes()[0]) - - def visitUnarySub(self, node, **kw): - return -self.visit(node.getChildNodes()[0]) - - def visitName(self, node, **kw): - if node.name == 'False': - return False - elif node.name == 'True': - return True - elif node.name == 'None': - return None - else: - raise SyntaxError("Unknown name: %s" % node.name) - else: + def visit(self, node): + cls = node.__class__ + meth = getattr(self, 'visit' + cls.__name__, self.default) + return meth(node) - def visit(self, node): - cls = node.__class__ - meth = getattr(self, 'visit' + cls.__name__, self.default) - return meth(node) + def default(self, node): + raise SyntaxError("Unsupported source construct: %s" + % node.__class__) - def default(self, node): - raise SyntaxError("Unsupported source construct: %s" - % node.__class__) + def visitExpression(self, node): + return self.visit(node.body) - def visitExpression(self, node): - return self.visit(node.body) + def visitNum(self, node): + return node.n - def visitNum(self, node): - return node.n + def visitStr(self, node): + return node.s - def visitStr(self, node): - return node.s + def visitBytes(self, node): + return node.s - def visitBytes(self, node): - return node.s + def visitDict(self, node,**kw): + return dict([(self.visit(k), self.visit(v)) + for k, v in zip(node.keys, node.values)]) - def visitDict(self, node,**kw): - return dict([(self.visit(k), self.visit(v)) - for k, v in zip(node.keys, node.values)]) + def visitTuple(self, node): + return tuple([self.visit(i) for i in node.elts]) - def visitTuple(self, node): - return tuple([self.visit(i) for i in node.elts]) + def visitList(self, node): + return [self.visit(i) for i in node.elts] - def visitList(self, node): - return [self.visit(i) for i in node.elts] + def visitUnaryOp(self, node): + import ast + if isinstance(node.op, ast.UAdd): + return +self.visit(node.operand) + elif isinstance(node.op, ast.USub): + return -self.visit(node.operand) + else: + raise SyntaxError("Unknown unary op: %r" % node.op) + + def visitName(self, node): + if node.id == 'False': + return False + elif node.id == 'True': + return True + elif node.id == 'None': + return None + else: + raise SyntaxError("Unknown name: %s" % node.id) - def visitUnaryOp(self, node): - import ast - if isinstance(node.op, ast.UAdd): - return +self.visit(node.operand) - elif isinstance(node.op, ast.USub): - return -self.visit(node.operand) - else: - raise SyntaxError("Unknown unary op: %r" % node.op) - - def visitName(self, node): - if node.id == 'False': - return False - elif node.id == 'True': - return True - elif node.id == 'None': - return None - else: - raise SyntaxError("Unknown name: %s" % node.id) + def visitNameConstant(self, node): + return node.value - def visitNameConstant(self, node): - return node.value def safe_eval(source): """ @@ -1148,28 +1104,20 @@ def safe_eval(source): >>> np.safe_eval('open("/home/user/.ssh/id_dsa").read()') Traceback (most recent call last): ... - SyntaxError: Unsupported source construct: compiler.ast.CallFunc + SyntaxError: Unsupported source construct: <class '_ast.Call'> """ # Local imports to speed up numpy's import time. import warnings - - with warnings.catch_warnings(): - # compiler package is deprecated for 3.x, which is already solved - # here - warnings.simplefilter('ignore', DeprecationWarning) - try: - import compiler - except ImportError: - import ast as compiler + import ast walker = SafeEval() try: - ast = compiler.parse(source, mode="eval") + res = ast.parse(source, mode="eval") except SyntaxError: raise try: - return walker.visit(ast) + return walker.visit(res) except SyntaxError: raise |