
Parsing an equation with custom functions in Python

I have a string that is a mathematical equation, but with some custom functions. I need to find all such functions and replace them with some code.

For example, I have a string:

a+b+f1(f2(x,y),x)

I want code that will replace (say) f2(x,y) with x+y^2 and f1(x,y) with sin(x+y).

It would be ideal if nested functions were supported, like in the example. However, it would still be useful if nesting was not supported.

As I understand from similar topics, this can be done using a compiler module, e.g. compiler.parse(eq). How can I work with the AST object created by compiler.parse(eq) to reconstruct my string, replacing all the functions found?

I only need to perform the substitution; the resulting string will then be used in another program. Evaluation is not needed.

Alex asked Jan 15 '15 10:01


2 Answers

Here is a minimal working example. It implements the binary operators +, -, *, / and **, unary + and -, and function calls. Operator precedence is made explicit with parentheses in the output.

It does slightly more than the example in the question requires:

from __future__ import print_function
import ast

def transform(eq,functions):
    class EqVisitor(ast.NodeVisitor):
        def visit_BinOp(self,node):
            #generate("=>BinOp")
            generate("(")
            self.visit(node.left)
            self.visit(node.op)
            #generate("ici",str(node.op),node._fields,node._attributes)
            #generate(dir(node.op))
            self.visit(node.right)
            generate(")")
            #ast.NodeVisitor.generic_visit(self,node)
        def visit_USub(self,node):
            generate("-")
        def visit_UAdd(self,node):
            generate("+")

        def visit_Sub(self,node):
            generate("-")
        def visit_Add(self,node):
            generate("+")
        def visit_Pow(self,node):
            generate("**")
        def visit_Mult(self,node):
            generate("*")
        def visit_Div(self,node):
            generate("/")
        def visit_Name(self,node):
            generate(node.id)
        def visit_Call(self,node):
            debug("function",node.func.id)
            if node.func.id in functions:
                debug("defined function")
                func_visit(functions[node.func.id],node.args)
                return
            debug("not defined function",node.func.id)
            #generate(node._fields)
            #generate("args")
            generate(node.func.id)
            generate("(")
            sep = ""
            for arg in node.args:
                generate (sep)
                self.visit(arg)
                sep=","
            generate(")")
        def visit_Num(self,node):
            generate(node.n)
        def generic_visit(self, node):


            debug ("\n",type(node).__name__)
            debug (node._fields)
            ast.NodeVisitor.generic_visit(self, node)

    def func_visit(definition,concrete_args):
        class FuncVisitor(EqVisitor):
            def visit_arguments(self,node):
                #generate("visit arguments")
                #generate(node._fields)
                self.arguments={}
                for concrete_arg,formal_arg in zip(concrete_args,node.args):
                    #generate(formal_arg._fields)
                    self.arguments[formal_arg.id]=concrete_arg
                debug(self.arguments)
            def visit_Name(self,node):
                debug("visit Name",node.id)
                if node.id in self.arguments:
                    eqV.visit(self.arguments[node.id])
                else:
                    generate(node.id)


        funcV=FuncVisitor()
        funcV.visit(ast.parse(definition))

    eqV=EqVisitor()
    result = []
    def generate(s):
        # the following line may be useful for debugging
        debug(str(s))
        result.append(str(s))
    eqV.visit(ast.parse(eq,mode="eval"))
    return "".join(result)
def debug(*args,**kwargs):
    #print(*args,**kwargs)
    pass

Usage:

functions= {
    "f1":"def f1(x,y):return x+y**2",
    "f2":"def f2(x,y):return sin(x+y)",
}
eq="-(a+b)+f1(f2(+x,y),z)*4/365.12-h"
print(transform(eq,functions))

Result

((-(a+b)+(((sin((+x+y))+(z**2))*4)/365.12))-h)

WARNING

The code works with Python 2.7. Because it depends on the details of the AST, it is not guaranteed to work with other versions of Python. As written, it does not work under Python 3.
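
For readers on Python 3, here is a rough sketch of the same idea. It is not a port of the code above: it assumes Python 3.9 or later (for ast.unparse), swaps the hand-written string generation for ast.NodeTransformer, and the class names Substituter and Inliner are made up for illustration. Two of the AST differences that matter are noted in the comments.

```python
import ast
import copy

def transform(eq, functions):
    # Parse each "def name(params): return expr" string once.
    parsed = {}
    for name, source in functions.items():
        func_def = ast.parse(source).body[0]
        # Python 3 stores parameters as ast.arg nodes (attribute .arg),
        # not as Name nodes with .id as in Python 2.
        params = [a.arg for a in func_def.args.args]
        body = func_def.body[0].value  # expression of the single return statement
        parsed[name] = (params, body)

    class Substituter(ast.NodeTransformer):
        """Replace parameter names with the argument expressions."""
        def __init__(self, mapping):
            self.mapping = mapping

        def visit_Name(self, node):
            if node.id in self.mapping:
                return copy.deepcopy(self.mapping[node.id])
            return node

    class Inliner(ast.NodeTransformer):
        def visit_Call(self, node):
            self.generic_visit(node)  # inline nested calls in the arguments first
            if isinstance(node.func, ast.Name) and node.func.id in parsed:
                params, body = parsed[node.func.id]
                mapping = dict(zip(params, node.args))
                # Copy the body so every expansion gets its own subtree.
                return Substituter(mapping).visit(copy.deepcopy(body))
            return node

    tree = Inliner().visit(ast.parse(eq, mode="eval"))
    # ast.unparse inserts parentheses only where precedence requires them.
    return ast.unparse(tree.body)

functions = {
    "f1": "def f1(x,y):return x+y**2",
    "f2": "def f2(x,y):return sin(x+y)",
}
out = transform("-(a+b)+f1(f2(+x,y),z)*4/365.12-h", functions)
print(out)
```

This sketch does not check the number of arguments and does not expand calls that appear inside a function body, so treat it as a starting point only.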

Xavier Combelle answered Oct 20 '22 03:10


Full substitution is quite tricky. Here is my attempt; it inlines expressions successfully, but not in all scenarios. The code works purely on the AST produced by the ast module and uses codegen to turn it back into source code. Stringifying and modifying an AST in general is covered in another SO Q/A: "Parse a .py file, read the AST, modify it, then write back the modified source code".

First we define a few helpers:

import ast
import codegen
import copy

def parseExpr(expr):
    # Strip:
    # Module(body=[Expr(value=
    return ast.parse(expr).body[0].value

def toSource(expr):
    return codegen.to_source(expr)

After that we define a substitution function using NodeTransformer. For example:

substitute(parseExpr("a + b"), { "a": parseExpr("1") }) # 1 + b

Simultaneous substitution of multiple variables is needed to avoid nasty situations. For example, consider substituting a + b for both a and b in a + b. The result should be (a + b) + (a + b). But if we first substitute a + b for a, we get (a + b) + b, and if we then substitute a + b for b, we get (a + (a + b)) + (a + b), which is the wrong result! So simultaneous substitution is important:

class NameTransformer(ast.NodeTransformer):
    def __init__(self, names):
        self.names = names

    def visit_Name(self, node):
        if node.id in self.names:
            return self.names[node.id]
        else:
            return node

def substitute(expr, names):
    print "substitute"
    for varName, varValue in names.iteritems():
        print "  name " + varName + " for " + toSource(varValue)
    print "  in " + toSource(expr)
    return NameTransformer(names).visit(expr)
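
The difference between simultaneous and sequential substitution can be checked with a small sketch. This one assumes Python 3.9+ and uses ast.unparse instead of codegen; the helper names parse_expr, substitute_names and NameSubstituter are hypothetical and not part of the code above.

```python
import ast
import copy

def parse_expr(source):
    # mode="eval" yields an Expression node; .body is the bare expression.
    return ast.parse(source, mode="eval").body

class NameSubstituter(ast.NodeTransformer):
    def __init__(self, names):
        self.names = names

    def visit_Name(self, node):
        # The returned replacement is not visited again, which is what
        # makes a single pass behave as a simultaneous substitution.
        if node.id in self.names:
            return copy.deepcopy(self.names[node.id])
        return node

def substitute_names(source, names):
    replacements = {name: parse_expr(value) for name, value in names.items()}
    return ast.unparse(NameSubstituter(replacements).visit(parse_expr(source)))

# Both variables replaced in one pass.
simultaneous = substitute_names("a + b", {"a": "a + b", "b": "a + b"})

# One variable at a time: the second pass also rewrites the b introduced by the first.
step1 = substitute_names("a + b", {"a": "a + b"})
sequential = substitute_names(step1, {"b": "a + b"})

print(simultaneous)
print(sequential)
```

With a = 1 and b = 10, the simultaneous result evaluates to 22 and the sequential one to 23, so the two really are different expressions.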

Then we write a similar NodeTransformer to find calls, where we can inline function definitions:

class CallTransformer(ast.NodeTransformer):
    def __init__(self, fnName, varNames, fnExpr):
        self.fnName = fnName
        self.varNames = varNames
        # substitute in new fn expr for each CallTransformer
        self.fnExpr = copy.deepcopy(fnExpr)
        self.modified = False

    def visit_Call(self, node):
        if (node.func.id == self.fnName):
            if len(node.args) == len(self.varNames):
                print "expand call to " + self.fnName + "(" + (", ".join(self.varNames)) + ")" + " with arguments "+ ", ".join(map(toSource, node.args))
                # We substitute in args too!
                old_node = node
                args = map(self.visit, node.args)
                names = dict(zip(self.varNames, args))
                node = substitute(self.fnExpr, names)
                self.modified = True
                return node
            else:
                raise Exception("invalid arity " + toSource(node))
        else:
            return self.generic_visit(node)

def substituteCalls(expr, definitions, n = 3):
    while True:
        if (n <= 0):
            break
        n -= 1

        modified = False
        for fnName, varNames, fnExpr in definitions:
            transformer = CallTransformer(fnName, varNames, fnExpr)
            expr = transformer.visit(expr)
            modified = modified or transformer.modified

        if not modified:
            break

    return expr

substituteCalls repeats the expansion pass, so recursive functions can be inlined too. There is an explicit limit on the number of passes (n), because some definitions might recurse infinitely (such as fact below). There is a bit of ugly-looking copying, but it is required to keep the different subtrees separate.
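
The bounded, repeated expansion can also be sketched for Python 3.9+ as follows. This is only an illustration under my own assumptions, not the code above: the names substitute_calls, OnePass and max_passes are invented here, and definitions are given as a dict mapping a name to (parameter names, body source). The body is deep-copied for every expansion because NodeTransformer edits trees in place.

```python
import ast
import copy

def parse_expr(source):
    return ast.parse(source, mode="eval").body

class Substituter(ast.NodeTransformer):
    def __init__(self, mapping):
        self.mapping = mapping

    def visit_Name(self, node):
        if node.id in self.mapping:
            return copy.deepcopy(self.mapping[node.id])
        return node

class OnePass(ast.NodeTransformer):
    """Expand every known call once; calls inside an inlined body wait for the next pass."""
    def __init__(self, definitions):
        self.definitions = definitions
        self.modified = False

    def visit_Call(self, node):
        self.generic_visit(node)  # expand calls inside the arguments first
        if isinstance(node.func, ast.Name) and node.func.id in self.definitions:
            params, body = self.definitions[node.func.id]
            if len(params) != len(node.args):
                raise ValueError("invalid arity: " + ast.unparse(node))
            self.modified = True
            # Fresh copy of the body per expansion, so expansions never share nodes.
            return Substituter(dict(zip(params, node.args))).visit(copy.deepcopy(body))
        return node

def substitute_calls(source, definitions, max_passes=3):
    parsed = {name: (params, parse_expr(body))
              for name, (params, body) in definitions.items()}
    tree = parse_expr(source)
    for _ in range(max_passes):  # explicit limit guards against endless recursion
        one_pass = OnePass(parsed)
        tree = one_pass.visit(tree)
        if not one_pass.modified:
            break
    return ast.unparse(tree)

print(substitute_calls("f1(f1(x, x), y)", {"f1": (["x", "y"], "x + y")}))
print(substitute_calls("fact(n)",
                       {"fact": (["n"], "n if n == 0 else n * fact(n - 1)")}))
```

For the recursive definition, each pass unrolls one more level and exactly one call to fact is left in the result once the limit is reached.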


And the example code:

if True:
    print "f1 first, unique variable names"
    ex = parseExpr("a+b+f1(f2(x, y), x)")
    ex = substituteCalls(ex, [
        ("f1", ["u", "v"], parseExpr("sin(u + v)")),
        ("f2", ["i", "j"], parseExpr("i + j ^ 2"))])
    print toSource(ex)
    print "---"

if True:
    print "f1 first"
    ex = parseExpr("a+b+f1(f2(x, y), x)")
    ex = substituteCalls(ex, [
        ("f1", ["x", "y"], parseExpr("sin(x + y)")),
        ("f2", ["x", "y"], parseExpr("x + y ^ 2"))])
    print toSource(ex)
    print "---"

if True:
    print "f2 first"
    ex = parseExpr("f1(f1(x, x), y)")
    ex = substituteCalls(ex, [
        ("f1", ["x", "y"], parseExpr("x + y"))])
    print toSource(ex)
    print "---"

if True:
    print "fact"
    ex = parseExpr("fact(n)")
    ex = substituteCalls(ex, [
        ("fact", ["n"], parseExpr("n if n == 0 else n * fact(n-1)"))])
    print toSource(ex)
    print "---"

Which prints out:

f1 first, unique variable names
expand call to f1(u, v) with arguments f2(x, y), x
substitute
  name u for f2(x, y)
  name v for x
  in sin((u + v))
expand call to f2(i, j) with arguments x, y
substitute
  name i for x
  name j for y
  in ((i + j) ^ 2)
((a + b) + sin((((x + y) ^ 2) + x)))
---
f1 first
expand call to f1(x, y) with arguments f2(x, y), x
substitute
  name y for x
  name x for f2(x, y)
  in sin((x + y))
expand call to f2(x, y) with arguments x, y
substitute
  name y for y
  name x for x
  in ((x + y) ^ 2)
((a + b) + sin((((x + y) ^ 2) + x)))
---
f2 first
expand call to f1(x, y) with arguments f1(x, x), y
expand call to f1(x, y) with arguments x, x
substitute
  name y for x
  name x for x
  in (x + y)
substitute
  name y for y
  name x for (x + x)
  in (x + x)
((x + x) + ((x + x) + x))
---
fact
expand call to fact(n) with arguments n
substitute
  name n for n
  in n if (n == 0) else (n * fact((n - 1)))
expand call to fact(n) with arguments (n - 1)
substitute
  name n for (n - 1)
  in n if (n == 0) else (n * fact((n - 1)))
expand call to fact(n) with arguments ((n - 1) - 1)
substitute
  name n for ((n - 1) - 1)
  in n if (n == 0) else (n * fact((n - 1)))
n if (n == 0) else (n * (n - 1) if ((n - 1) == 0) else ((n - 1) * ((n - 1) - 1) if (((n - 1) - 1) == 0) else (((n - 1) - 1) * fact((((n - 1) - 1) - 1)))))

Unfortunately, the codegen version on PyPI is buggy: it doesn't parenthesise expressions properly, even when the AST says it should. I used jbremer/codegen (pip install git+git://github.com/jbremer/codegen). It adds unnecessary parentheses too, but that is better than none at all. Thanks to @XavierCombelle for the tip.
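
On Python 3.9 or later, the standard library's ast.unparse can play the role of codegen, so no third-party package is needed there. A minimal sketch of the two helpers under that assumption (the snake_case names are mine, and this has not been tested against the Python 2 code above):

```python
import ast

def parse_expr(source):
    # mode="eval" parses a single expression; .body strips the Expression wrapper.
    return ast.parse(source, mode="eval").body

def to_source(node):
    # ast.unparse is available from Python 3.9 onwards.
    return ast.unparse(node)

tree = parse_expr("(a+b)*c - f(x ,y)")
round_trip = to_source(tree)
print(round_trip)
```

A quick sanity check is to parse the generated text again and compare ast.dump of both trees; they should be identical if the parentheses were emitted correctly.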


Substitution gets trickier if you have anonymous functions, i.e. lambda. Then you need to rename variables. You could search for lambda calculus together with substitution or implementation. However, I had no luck finding any articles that use Python for the task.
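
To illustrate the renaming problem: substituting x for y in lambda x: x + y must not produce lambda x: x + x, because the inserted x would be captured by the parameter. Below is a rough sketch of one way to rename parameters first, assuming Python 3.9+. All names in it are hypothetical. It only handles plain positional lambda parameters without defaults, and it does not check that the generated names are unused elsewhere.

```python
import ast
import copy
import itertools

_fresh_ids = itertools.count(1)  # shared counter for generated parameter names

def names_in(node):
    """All identifiers that occur as Name nodes inside `node`."""
    return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}

class CaptureAvoidingSubstituter(ast.NodeTransformer):
    def __init__(self, mapping):
        self.mapping = mapping  # name -> replacement AST

    def visit_Name(self, node):
        if node.id in self.mapping:
            return copy.deepcopy(self.mapping[node.id])
        return node

    def visit_Lambda(self, node):
        params = {a.arg for a in node.args.args}
        # Parameters shadow outer names, so they are not substituted in the body.
        inner = {k: v for k, v in self.mapping.items() if k not in params}
        incoming = set().union(*(names_in(v) for v in inner.values()))
        # Rename every parameter that would capture a name in a replacement.
        renames = {}
        for a in node.args.args:
            if a.arg in incoming:
                fresh = "%s_%d" % (a.arg, next(_fresh_ids))
                renames[a.arg] = ast.Name(id=fresh, ctx=ast.Load())
                a.arg = fresh
        if renames:
            node.body = CaptureAvoidingSubstituter(renames).visit(node.body)
        node.body = CaptureAvoidingSubstituter(inner).visit(node.body)
        return node

def substitute(source, mapping):
    tree = ast.parse(source, mode="eval").body
    parsed = {k: ast.parse(v, mode="eval").body for k, v in mapping.items()}
    return ast.unparse(CaptureAvoidingSubstituter(parsed).visit(tree))

result = substitute("lambda x: x + y", {"y": "x"})
print(result)
```

The parameter is renamed before the substitution, so the x that replaces y still refers to the outer x.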

phadej answered Oct 20 '22 01:10