I have this mathematical expression:
tree = ast.parse('1 + 2 + 3 + x')
which corresponds to this abstract syntax tree:
Module(body=[Expr(value=BinOp(left=BinOp(left=BinOp(left=Num(n=1), op=Add(), right=Num(n=2)), op=Add(), right=Num(n=3)), op=Add(), right=Name(id='x', ctx=Load())))])
and I would like to simplify it - that is, get this:
Module(body=[Expr(value=BinOp(left=Num(n=6), op=Add(), right=Name(id='x', ctx=Load())))])
According to the documentation, I should use the NodeTransformer class. A suggestion in the docs says the following:
Keep in mind that if the node you’re operating on has child nodes you must either transform the child nodes yourself or call the generic_visit() method for the node first.
I tried implementing my own transformer:
class Evaluator(ast.NodeTransformer):
    def visit_BinOp(self, node):
        print('Evaluating ', ast.dump(node))
        for child in ast.iter_child_nodes(node):
            self.visit(child)
        if type(node.left) == ast.Num and type(node.right) == ast.Num:
            print(ast.literal_eval(node))
            return ast.copy_location(ast.Subscript(value=ast.literal_eval(node)), node)
        else:
            return node
What it should do in this specific case is simplify 1 + 2 into 3, and then 3 + 3 into 6. It does evaluate the binary operations I want to simplify, but it doesn't update the original syntax tree. I tried different approaches, but I still don't see how to recursively simplify all binary operations (in a depth-first manner). Could anyone point me in the right direction?
Thank you.
There are three possible return values for the visit_* methods:
1. None, which means the node will be deleted,
2. node (the node itself), which means no change will be applied,
3. a new node, which will replace the old one.
So when you want to replace the BinOp with a Num, you need to return a new Num node. The expression cannot be evaluated via ast.literal_eval, as this function only evaluates literals (not arbitrary expressions). Instead you can use eval, for example.
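To make the return values of the visit_* methods concrete, here is a minimal sketch. The class name ReturnValueDemo and the sample source are made up for illustration, and ast.unparse assumes Python 3.9 or newer.

```python
import ast


class ReturnValueDemo(ast.NodeTransformer):
    """Hypothetical transformer showing the three kinds of return value."""

    def visit_Pass(self, node):
        # None: the node is removed from its parent.
        return None

    def visit_Name(self, node):
        if node.id == 'y':
            # A new node: it replaces the original in the tree.
            return ast.copy_location(ast.Name(id='z', ctx=node.ctx), node)
        # The node itself: nothing changes.
        return node


tree = ast.parse('pass\nx = y')
tree = ast.fix_missing_locations(ReturnValueDemo().visit(tree))
result_source = ast.unparse(tree)  # assumes Python 3.9+
print(result_source)  # x = z
```

Note that the replacement only takes effect because the visitor's return value is stored back into the parent, which is what generic_visit does. Calling self.visit(child) and discarding the result, as in the question, leaves the tree unchanged.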
So you could use the following node transformer class:
import ast

class Evaluator(ast.NodeTransformer):
    ops = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        # define more here
    }

    def visit_BinOp(self, node):
        self.generic_visit(node)
        if isinstance(node.left, ast.Num) and isinstance(node.right, ast.Num):
            # On Python <= 3.6 you can use ast.literal_eval.
            # value = ast.literal_eval(node)
            value = eval(f'{node.left.n} {self.ops[type(node.op)]} {node.right.n}')
            return ast.Num(n=value)
        return node

tree = ast.parse('1 + 2 + 3 + x')
tree = ast.fix_missing_locations(Evaluator().visit(tree))
print(ast.dump(tree))
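As a possible variation on the class above, the same folding can be sketched without eval by mapping operator nodes to functions from the operator module. This version uses ast.Constant, since ast.Num has been deprecated since Python 3.8. The names ConstantFolder, _is_number and fold_constants are hypothetical, and ast.unparse assumes Python 3.9 or newer. Skipping bool values and leaving division by zero untouched are design choices of this sketch, not something the answer above prescribes.

```python
import ast
import operator


class ConstantFolder(ast.NodeTransformer):
    """Hypothetical variant of the Evaluator above, for Python 3.9+."""

    # Map operator node types to plain functions (extend as needed).
    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }

    @staticmethod
    def _is_number(node):
        # Only fold int/float/complex constants; bool is excluded on purpose.
        return (
            isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float, complex))
            and not isinstance(node.value, bool)
        )

    def visit_BinOp(self, node):
        # Fold the children first so the simplification is depth-first.
        self.generic_visit(node)
        func = self.ops.get(type(node.op))
        if func and self._is_number(node.left) and self._is_number(node.right):
            try:
                value = func(node.left.value, node.right.value)
            except ZeroDivisionError:
                return node  # leave e.g. 1 / 0 in place so it fails at runtime
            return ast.copy_location(ast.Constant(value=value), node)
        return node


def fold_constants(source):
    """Parse source, fold numeric BinOps, and return the new source text."""
    tree = ConstantFolder().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


print(fold_constants('1 + 2 + 3 + x'))  # 6 + x
```

Like the original class, this only folds a BinOp whose two operands are both constants after the children have been visited. An expression such as x + 1 + 2 parses as (x + 1) + 2 and is therefore left as it is.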