How to recursively simplify a mathematical expression with AST in Python 3?

I have this mathematical expression:

tree = ast.parse('1 + 2 + 3 + x')

which corresponds to this abstract syntax tree:

Module(body=[Expr(value=BinOp(left=BinOp(left=BinOp(left=Num(n=1), op=Add(), right=Num(n=2)), op=Add(), right=Num(n=3)), op=Add(), right=Name(id='x', ctx=Load())))])

and I would like to simplify it, i.e. get this:

Module(body=[Expr(value=BinOp(left=Num(n=6), op=Add(), right=Name(id='x', ctx=Load())))])
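A note on Python versions: the dumps above use ast.Num, which is what Python 3.7 and earlier produce. As a quick check, assuming Python 3.8 or newer, the same source parses to ast.Constant nodes instead (ast.Num has been deprecated since Python 3.8 and was removed in Python 3.14), which matters when a transformer matches on node types:

```python
import ast

tree = ast.parse('1 + 2 + 3 + x')

# The outermost BinOp is (1 + 2 + 3) + x; + is left-associative,
# so the leftmost leaf is reached by following .left three times.
outer = tree.body[0].value
leftmost = outer.left.left.left

print(ast.dump(outer))
print(type(leftmost).__name__, leftmost.value)  # Constant 1 on Python 3.8+
```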

According to the documentation, I should use the NodeTransformer class. A suggestion in the docs says the following:

Keep in mind that if the node you’re operating on has child nodes you must either transform the child nodes yourself or call the generic_visit() method for the node first.
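To illustrate what that note means, here is a minimal sketch that renames x to y (it assumes Python 3.9+ for ast.unparse; the class names DiscardingRenamer and GenericRenamer are hypothetical). The first transformer loops over the children and calls self.visit on each but throws away the returned node; the second calls generic_visit, which writes the returned node back into the parent:

```python
import ast

class DiscardingRenamer(ast.NodeTransformer):
    """Visits a BinOp's children but ignores what visit() returns."""

    def visit_BinOp(self, node):
        for child in ast.iter_child_nodes(node):
            self.visit(child)  # the replacement node returned here is lost
        return node

    def visit_Name(self, node):
        # Replace the name x with y; leave every other name alone.
        if node.id == 'x':
            return ast.copy_location(ast.Name(id='y', ctx=node.ctx), node)
        return node

class GenericRenamer(DiscardingRenamer):
    """Same visit_Name, but generic_visit stores the returned nodes."""

    def visit_BinOp(self, node):
        self.generic_visit(node)  # replaces node.left / node.right as needed
        return node

print(ast.unparse(DiscardingRenamer().visit(ast.parse('1 + x'))))  # 1 + x
print(ast.unparse(GenericRenamer().visit(ast.parse('1 + x'))))     # 1 + y
```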

I tried implementing my own transformer:

import ast

class Evaluator(ast.NodeTransformer):
    def visit_BinOp(self, node):
        print('Evaluating ', ast.dump(node))
        for child in ast.iter_child_nodes(node):
            self.visit(child)

        if type(node.left) == ast.Num and type(node.right) == ast.Num:
            print(ast.literal_eval(node))
            return ast.copy_location(ast.Subscript(value=ast.literal_eval(node)), node)
        else:
            return node

What it should do in this specific case is simplify 1 + 2 into 3, and then 3 + 3 into 6. It does evaluate the binary operations I want to simplify, but the original syntax tree is not updated. I have tried different approaches, but I still don't see how to recursively simplify all binary operations (depth-first). Could anyone point me in the right direction?

Thank you.

Guillermo Mosse asked Oct 25 '25 02:10



1 Answer

There are three possible return values for the visit_* methods:

  1. None, which means the node will be removed;
  2. node (the node itself), which means no change will be applied;
  3. a new node, which will replace the old one.
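A small sketch of all three cases, assuming Python 3.9+ (for ast.Constant and ast.unparse). ReturnValueDemo is a hypothetical name, and its rules (drop pass statements, keep the name keep, replace any other name with 0) are arbitrary choices for illustration:

```python
import ast

class ReturnValueDemo(ast.NodeTransformer):
    def visit_Pass(self, node):
        return None  # case 1: the statement is removed from the body list

    def visit_Name(self, node):
        if node.id == 'keep':
            return node  # case 2: the node is left unchanged
        # case 3: a new node replaces the old one
        return ast.copy_location(ast.Constant(value=0), node)

tree = ast.parse('pass\nkeep + other')
tree = ast.fix_missing_locations(ReturnValueDemo().visit(tree))
print(ast.unparse(tree))  # keep + 0
```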

So when you want to replace the BinOp with a Num, you need to return a new Num node from visit_BinOp (not a bare Python value wrapped in some other node type such as ast.Subscript). On Python 3.7 and later the value cannot be computed with ast.literal_eval, because that function only evaluates literals, not arbitrary expressions. Instead you can use eval, for example.
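If you would rather avoid eval, one alternative (not the approach used in the class below) is to map the AST operator classes to functions from the standard operator module. A minimal sketch; BIN_OPS and apply_binop are hypothetical names, and only four operators are covered. It also shows that on Python 3.7+ literal_eval rejects plain arithmetic such as 1 + 2:

```python
import ast
import operator

# Map AST operator classes to the matching operator-module functions.
# Only these four are handled here; extend as needed.
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}

def apply_binop(op, left, right):
    """Evaluate `left op right` for two plain numbers without eval()."""
    return BIN_OPS[type(op)](left, right)

print(apply_binop(ast.Add(), 1, 2))  # 3

# On Python 3.7+, literal_eval raises ValueError for 1 + 2.
try:
    ast.literal_eval('1 + 2')
except ValueError as exc:
    print('literal_eval refused:', exc)
```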

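Calling self.generic_visit(node) first visits the children depth-first and writes whatever they return back into the node. That is what the loop over ast.iter_child_nodes in the question is missing: it calls self.visit(child) but discards the result. So you could use the following node transformer class: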

import ast

class Evaluator(ast.NodeTransformer):
    ops = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        # define more here
    }

    def visit_BinOp(self, node):
        self.generic_visit(node)
        if (isinstance(node.left, ast.Num) and isinstance(node.right, ast.Num)
                and type(node.op) in self.ops):  # skip operators not listed in ops
            # On Python <= 3.6 you can use ast.literal_eval.
            # value = ast.literal_eval(node)
            value = eval(f'{node.left.n} {self.ops[type(node.op)]} {node.right.n}')
            return ast.Num(n=value)
        return node

tree = ast.parse('1 + 2 + 3 + x')
tree = ast.fix_missing_locations(Evaluator().visit(tree))
print(ast.dump(tree))
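On Python 3.8+ the parser emits ast.Constant rather than ast.Num, so a variant of the class above written against Constant may be more future-proof. The following is a sketch under that assumption (ConstantFolder is a hypothetical name, and it needs Python 3.9+ for ast.unparse). It keeps the generic_visit-first structure but uses the operator module instead of eval, and it leaves division by zero and unsupported operators unfolded:

```python
import ast
import operator

class ConstantFolder(ast.NodeTransformer):
    """Folds BinOp nodes whose operands are both numeric constants."""

    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }

    def visit_BinOp(self, node):
        # Fold the children first so nested sums collapse bottom-up.
        self.generic_visit(node)
        func = self.ops.get(type(node.op))
        numeric = (int, float, complex)
        if (
            func is not None
            and isinstance(node.left, ast.Constant)
            and isinstance(node.right, ast.Constant)
            and isinstance(node.left.value, numeric)
            and isinstance(node.right.value, numeric)
        ):
            try:
                value = func(node.left.value, node.right.value)
            except ZeroDivisionError:
                return node  # leave e.g. 1 / 0 untouched
            return ast.copy_location(ast.Constant(value=value), node)
        return node

tree = ast.parse('1 + 2 + 3 + x')
tree = ast.fix_missing_locations(ConstantFolder().visit(tree))
print(ast.unparse(tree))  # 6 + x
```

Because + is left-associative, 'x + 1 + 2' parses as (x + 1) + 2, so this sketch leaves it unchanged; folding that case would require reassociating the operands, which it does not attempt.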
a_guest answered Oct 27 '25 00:10



