summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py144
1 files changed, 46 insertions, 98 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index df0052493..4a68dde18 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1010,103 +1010,59 @@ class SafeEval(object):
"""
- if sys.version_info[0] < 3:
- def visit(self, node, **kw):
- cls = node.__class__
- meth = getattr(self, 'visit'+cls.__name__, self.default)
- return meth(node, **kw)
-
- def default(self, node, **kw):
- raise SyntaxError("Unsupported source construct: %s"
- % node.__class__)
-
- def visitExpression(self, node, **kw):
- for child in node.getChildNodes():
- return self.visit(child, **kw)
-
- def visitConst(self, node, **kw):
- return node.value
-
- def visitDict(self, node,**kw):
- return dict(
- [(self.visit(k), self.visit(v)) for k, v in node.items]
- )
-
- def visitTuple(self, node, **kw):
- return tuple([self.visit(i) for i in node.nodes])
-
- def visitList(self, node, **kw):
- return [self.visit(i) for i in node.nodes]
-
- def visitUnaryAdd(self, node, **kw):
- return +self.visit(node.getChildNodes()[0])
-
- def visitUnarySub(self, node, **kw):
- return -self.visit(node.getChildNodes()[0])
-
- def visitName(self, node, **kw):
- if node.name == 'False':
- return False
- elif node.name == 'True':
- return True
- elif node.name == 'None':
- return None
- else:
- raise SyntaxError("Unknown name: %s" % node.name)
- else:
+ def visit(self, node):
+ cls = node.__class__
+ meth = getattr(self, 'visit' + cls.__name__, self.default)
+ return meth(node)
- def visit(self, node):
- cls = node.__class__
- meth = getattr(self, 'visit' + cls.__name__, self.default)
- return meth(node)
+ def default(self, node):
+ raise SyntaxError("Unsupported source construct: %s"
+ % node.__class__)
- def default(self, node):
- raise SyntaxError("Unsupported source construct: %s"
- % node.__class__)
+ def visitExpression(self, node):
+ return self.visit(node.body)
- def visitExpression(self, node):
- return self.visit(node.body)
+ def visitNum(self, node):
+ return node.n
- def visitNum(self, node):
- return node.n
+ def visitStr(self, node):
+ return node.s
- def visitStr(self, node):
- return node.s
+ def visitBytes(self, node):
+ return node.s
- def visitBytes(self, node):
- return node.s
+ def visitDict(self, node,**kw):
+ return dict([(self.visit(k), self.visit(v))
+ for k, v in zip(node.keys, node.values)])
- def visitDict(self, node,**kw):
- return dict([(self.visit(k), self.visit(v))
- for k, v in zip(node.keys, node.values)])
+ def visitTuple(self, node):
+ return tuple([self.visit(i) for i in node.elts])
- def visitTuple(self, node):
- return tuple([self.visit(i) for i in node.elts])
+ def visitList(self, node):
+ return [self.visit(i) for i in node.elts]
- def visitList(self, node):
- return [self.visit(i) for i in node.elts]
+ def visitUnaryOp(self, node):
+ import ast
+ if isinstance(node.op, ast.UAdd):
+ return +self.visit(node.operand)
+ elif isinstance(node.op, ast.USub):
+ return -self.visit(node.operand)
+ else:
+ raise SyntaxError("Unknown unary op: %r" % node.op)
+
+ def visitName(self, node):
+ if node.id == 'False':
+ return False
+ elif node.id == 'True':
+ return True
+ elif node.id == 'None':
+ return None
+ else:
+ raise SyntaxError("Unknown name: %s" % node.id)
- def visitUnaryOp(self, node):
- import ast
- if isinstance(node.op, ast.UAdd):
- return +self.visit(node.operand)
- elif isinstance(node.op, ast.USub):
- return -self.visit(node.operand)
- else:
- raise SyntaxError("Unknown unary op: %r" % node.op)
-
- def visitName(self, node):
- if node.id == 'False':
- return False
- elif node.id == 'True':
- return True
- elif node.id == 'None':
- return None
- else:
- raise SyntaxError("Unknown name: %s" % node.id)
+ def visitNameConstant(self, node):
+ return node.value
- def visitNameConstant(self, node):
- return node.value
def safe_eval(source):
"""
@@ -1148,28 +1104,20 @@ def safe_eval(source):
>>> np.safe_eval('open("/home/user/.ssh/id_dsa").read()')
Traceback (most recent call last):
...
- SyntaxError: Unsupported source construct: compiler.ast.CallFunc
+ SyntaxError: Unsupported source construct: <class '_ast.Call'>
"""
# Local imports to speed up numpy's import time.
import warnings
-
- with warnings.catch_warnings():
- # compiler package is deprecated for 3.x, which is already solved
- # here
- warnings.simplefilter('ignore', DeprecationWarning)
- try:
- import compiler
- except ImportError:
- import ast as compiler
+ import ast
walker = SafeEval()
try:
- ast = compiler.parse(source, mode="eval")
+ res = ast.parse(source, mode="eval")
except SyntaxError:
raise
try:
- return walker.visit(ast)
+ return walker.visit(res)
except SyntaxError:
raise