
How to transpile Python Compare AST nodes to C?

Let's start by considering Python 3.8.5's grammar. In this case I'm interested in figuring out how to transpile Python comparisons to C.

For the sake of simplicity, let's assume we're dealing with a very small, trivial Python subset and we just want to transpile trivial Compare expressions:

expr = Compare(expr left, cmpop* ops, expr* comparators)

If I'm not mistaken, in Python an expression such as a<b<c is converted into something like a<b && b<c where b is only evaluated once... so I guess in C you should store b in a temporary, something like int t=b; bool v=a<t && t<c, in order to prevent b from being evaluated more than once when the first clause is true.
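The evaluation rule described above can be checked directly in Python. This is a minimal sketch; `traced` is a hypothetical helper introduced only for this illustration, and it records the order in which the operands of a chained comparison are evaluated.

```python
calls = []


def traced(name, value):
    """Record that an operand was evaluated, then return its value."""
    calls.append(name)
    return value


# Chain that succeeds: every operand is evaluated exactly once.
result_true = traced("a", 1) < traced("b", 2) < traced("c", 3)
calls_true = list(calls)

# Chain that fails at the first comparison: "c" is never evaluated.
calls.clear()
result_false = traced("a", 5) < traced("b", 2) < traced("c", 3)
calls_false = list(calls)

print(result_true, calls_true)    # True ['a', 'b', 'c']
print(result_false, calls_false)  # False ['a', 'b']
```

So a faithful C translation has two properties to preserve: the middle operand is evaluated once, and later operands are skipped as soon as one comparison fails.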

Unfortunately I don't know how to put that into code, so far this is what I've got:

import ast
import shutil
import textwrap
from subprocess import PIPE
from subprocess import Popen


class Visitor(ast.NodeVisitor):
    def visit(self, node):
        ret = super().visit(node)
        if ret is None:
            raise Exception("Unsupported node")
        return ret

    def visit_Expr(self, node):
        return f"{self.visit(node.value)};"

    def visit_Eq(self, node):
        return "=="

    def visit_Lt(self, node):
        return "<"

    def visit_LtE(self, node):
        return "<="

    def visit_Load(self, node):
        return "//load"

    def visit_Name(self, node):
        return f"{node.id}"

    def visit_Compare(self, node):
        left = self.visit(node.left)
        ops = [self.visit(x) for x in node.ops]
        comparators = [self.visit(x) for x in node.comparators]

        if len(ops) == 1 and len(comparators) == 1:
            return f"({left} {ops[0]} {comparators[0]})"
        else:
            lhs = ",".join([f"'{v}'" for v in ops])
            rhs = ",".join([f"{v}" for v in comparators])
            return f"cmp<{lhs}>({rhs})"

    def visit_Call(self, node):
        func = self.visit(node.func)
        args = [self.visit(x) for x in node.args]
        # keywords = [self.visit(x) for x in node.keywords]
        return f"{func}({','.join(args)})"

    def visit_Module(self, node):
        return f"{''.join([self.visit(x) for x in node.body])}"

    def visit_Num(self, node):
        return node.n


if __name__ == "__main__":
    out = Visitor().visit(
        ast.parse(
            textwrap.dedent(
                """
            1 == 1<3
            1 == (1<3)
            1 == (0 < foo(0 <= bar() < 3, baz())) < (4 < 5)
            foo(0 <= bar() < 3, baz())
        """
            )
        )
    )

    if shutil.which("clang-format"):
        cmd = "clang-format -style webkit -offset 0 -length {} -assume-filename None"
        p = Popen(
            cmd.format(len(out)), stdout=PIPE, stdin=PIPE, stderr=PIPE, shell=True
        )
        out = p.communicate(input=out.encode("utf-8"))[0].decode("utf-8")
        print(out)
    else:
        print(out)

As you can see, the output is some sort of non-compilable C output:

cmp<'==', '<'>(1, 3);
(1 == (1 < 3));
cmp<'==', '<'>((0 < foo(cmp<'<=', '<'>(bar(), 3), baz())), (4 < 5));
foo(cmp<'<=', '<'>(bar(), 3), baz());

Question: what would be the algorithm that would allow me to convert Python Compare expressions to C? A working Python example would be ideal here, but some general pseudocode that lets me improve the provided snippet would also be fine.

asked Jul 24 '20 08:07 by BPL


2 Answers

The correct translation for:

if 1 == 2 < 3:

is:

int i1 = 1;
int i2 = 2;
int i3 = 3;
if(i1 == i2 && i2 < i3) {

(Edit: this is still not correct because it doesn't short-circuit: in Python, the 3 is not evaluated at all when 1 == 2 is false.)

Alternatively, the last one doesn't have to be a temporary variable:

int i1 = 1;
int i2 = 2;
if(i1 == i2 && i2 < 3) {

Alternatively: (this version keeps the expressions being compared inside the comparison expression)

int i1;
if(1 == (i1 = 2) && i1 < 3) {

Your compiler needs to know that the values being compared are ints so that it can declare the temporary variables. It also needs to choose temporary variable names that are different every time, so that if you have two of these comparisons it doesn't try to generate multiple variables with the same name.
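As a rough illustration of those two requirements, here is a minimal sketch that emits the assignment-inside-the-condition form for a Compare node. `CompareEmitter`, `C_OPS`, `fresh`, and the `t1`, `t2`, ... naming scheme are all hypothetical names chosen for this example, and the sketch assumes an int-only subset where operands are constants, names, or nested comparisons.

```python
import ast
import itertools

# C spellings for the comparison operators this sketch supports.
C_OPS = {
    ast.Eq: "==", ast.NotEq: "!=",
    ast.Lt: "<", ast.LtE: "<=",
    ast.Gt: ">", ast.GtE: ">=",
}


class CompareEmitter:
    """Hypothetical helper: emits a C expression for an int-only Compare node."""

    def __init__(self):
        self._ids = itertools.count(1)
        self.decls = []  # C declarations to print before the statement

    def fresh(self):
        """Declare a new int temporary with a name that was not used before."""
        name = f"t{next(self._ids)}"
        self.decls.append(f"int {name};")
        return name

    def operand(self, node):
        if isinstance(node, ast.Constant):
            return str(node.value)
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Compare):
            return self.compare(node)
        raise NotImplementedError(type(node).__name__)

    def compare(self, node):
        left = self.operand(node.left)
        parts = []
        last = len(node.ops) - 1
        for i, (op, right_node) in enumerate(zip(node.ops, node.comparators)):
            right = self.operand(right_node)
            symbol = C_OPS[type(op)]
            if i < last:
                # A middle operand is used twice: store it once, inside the
                # condition, so C's && still skips it when an earlier test fails.
                tmp = self.fresh()
                parts.append(f"{left} {symbol} ({tmp} = {right})")
                left = tmp
            else:
                parts.append(f"{left} {symbol} {right}")
        return "(" + " && ".join(parts) + ")"


emitter = CompareEmitter()
tree = ast.parse("1 == 2 < 3", mode="eval")
c_expr = emitter.compare(tree.body)
print("\n".join(emitter.decls))  # int t1;
print(f"if {c_expr} {{")         # if (1 == (t1 = 2) && t1 < 3) {
```

Because the assignment sits to the right of the first comparison and `&&` evaluates left to right, the temporary is only written when it is needed, which keeps the short-circuit behaviour. Operators outside `C_OPS` (such as `in` or `is`) raise a `KeyError` in this sketch.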

You might realise that it's okay to evaluate the expression 2 more than once, so the compiler can generate this code:

if(1 == 2 && 2 < 3) {

but that's an optional extra.

Note that there could be more than one of these in the same expression:

if 1 < (2 if (6 < 7 < 8) else 3) < 4:

translates to something like this:

int i1 = 1;
    int i2 = 6;
    int i3 = 7;
    int i4 = 8;
int i5 = (i2 < i3 && i3 < i4 ? 2 : 3);
int i6 = 4;
if(i1 < i5 && i5 < i6) {

or:

int i1;
int i2;
if(1 < (i1 = (6 < (i2 = 7) && i2 < 8 ? 2 : 3)) && i1 < 4) {
//            ^^^^^^^^^^^^^^^^^^^^^^ inside
// ^^^^^^^^^^^                               ^^^^^^^^^^^ outside
answered Dec 14 '22 00:12 by user253751


An additional complication when converting Compare expressions is that sub-expressions which are used more than once after the split must not be evaluated more than once. This is particularly important if they have side effects, such as a function call.

One could take the sub-expressions and declare them as variables in advance to avoid multiple evaluations.

There is a clever method for converting Python comparison expressions to JavaScript from a guy named Alexander Schepanovski. He explains his whole solution in detail in his blog post: http://hackflow.com/blog/2015/04/12/metaprogramming-beyond-decency-part-2/.

Basically the same approach can be applied when transpiling to C.

He determines pairs of adjacent operands. This is necessary to convert chained comparisons into separate comparisons, in which the 'middle' operand is copied and becomes the left operand of the second, split-off subcomparison.
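To make the pairing step concrete, here is a small sketch of how the operands of a Compare node line up with its operators. The expression `a < b <= c == d` and the variable names are made up for this illustration.

```python
import ast

# Parse a single chained comparison and take the Compare node itself.
node = ast.parse("a < b <= c == d", mode="eval").body

# All operands in source order: the left one plus every comparator.
operands = [node.left] + node.comparators

# Adjacent pairs: each middle operand appears once as a right-hand side
# and once as the left-hand side of the next subcomparison.
pairs = list(zip(operands, operands[1:]))

triples = [
    (left.id, type(op).__name__, right.id)
    for op, (left, right) in zip(node.ops, pairs)
]
print(triples)  # [('a', 'Lt', 'b'), ('b', 'LtE', 'c'), ('c', 'Eq', 'd')]
```

There is always exactly one more operand than there are operators, so zipping the operators with the adjacent pairs covers the whole chain.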

A kind of symbol table could be used to associate the variables with sub-expressions. The naming of the variable can be done by a simple counter.

The variables can be output when visiting an expression node. To get an output in C for the expressions given as an example in the question, you can simply emit a printf.

For further simplification, we can assume that the small, trivial Python subset only has to deal with int expressions.

Python Code

I have taken your snippet and slightly modified it according to the above points so that it is a self-contained example that outputs compilable C code for your sample expressions.

import ast
import itertools
import textwrap


def pairwise(iterable):
    """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


class Visitor(ast.NodeVisitor):
    def __init__(self):
        self.varCounter = 0
        self.varTable = []

    def visit_Expr(self, node):
        code = self.visit(node.value)
        variables = '\n'.join(self.varTable)
        self.varTable = []
        return f'{variables}\nprintf("%d\\n", {code});\n'

    def visit_Eq(self, node):
        return "=="

    def visit_Lt(self, node):
        return '<'

    def visit_LtE(self, node):
        return '<='

    def visit_Gt(self, node):
        return ">"

    def visit_GtE(self, node):
        return ">="

    def visit_Name(self, node):
        return str(node.id)

    # see http://hackflow.com/blog/2015/04/12/metaprogramming-beyond-decency-part-2/
    def visit_Compare(self, node):
        ops = node.ops
        operands = [node.left] + node.comparators
        variables = []
        for o in operands:
            self.varCounter += 1
            num = self.varCounter
            op = self.visit(o)
            variables.append((num, op))
            self.varTable.append(f'int t{num} = {op};')

        pairs = pairwise(variables)  # adjacent pairs of operands

        return ' && '.join('%s(%s %s %s)' %
                             ('!' if isinstance(op, ast.NotIn) else '',
                              f't{l[0]}', self.visit(op), f't{r[0]}')
                             for op, (l, r) in zip(ops, pairs))

    def visit_Call(self, node):
        args = [self.visit(x) for x in node.args]
        return self.visit(node.func) + "(" + ", ".join(args) + ")"

    # ast.Num is deprecated since Python 3.8; numeric literals are parsed as ast.Constant
    def visit_Constant(self, node):
        return str(node.value)


def main():
    analyzer = Visitor()
    tree = ast.parse(
        textwrap.dedent(
            """
            1 == 1<3
            1 == (1<3)
            1 == (0 < foo(0 <= bar() < 3, baz())) < (4 < 5)
            foo(0 <= bar() < 3, baz())
            """
        )
    )

    # print(ast.dump(tree))

    for node in ast.iter_child_nodes(tree):
        c = analyzer.visit(node)
        print(c)


if __name__ == '__main__':
    main()

Test Run

When you run the Python program, the following is displayed in the debug console:

int t1 = 1;
int t2 = 1;
int t3 = 3;
printf("%d\n", (t1 == t2) && (t2 < t3));

int t4 = 1;
int t6 = 1;
int t7 = 3;
int t5 = (t6 < t7);
printf("%d\n", (t4 == t5));

int t8 = 1;
int t10 = 0;
int t12 = 0;
int t13 = bar();
int t14 = 3;
int t11 = foo((t12 <= t13) && (t13 < t14), baz());
int t9 = (t10 < t11);
int t16 = 4;
int t17 = 5;
int t15 = (t16 < t17);
printf("%d\n", (t8 == t9) && (t9 < t15));

int t18 = 0;
int t19 = bar();
int t20 = 3;
printf("%d\n", foo((t18 <= t19) && (t19 < t20), baz()));

Of course there is a way to simplify this further. For example, constant expressions do not need to be assigned to a variable. And of course there are many more details to consider. But this should be a starting point that outputs compilable C code for your example data.
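One possible shape for the constant-expression simplification is sketched below. `needs_temporary` and `plan` are hypothetical helpers, not part of the code above, and the sketch takes the conservative view that only literal constants are safe to repeat.

```python
import ast


def needs_temporary(node):
    """Return True unless the operand is a literal constant.

    Assumption for this sketch: anything that is not an ast.Constant
    (calls, nested comparisons, even -1, which parses as a UnaryOp)
    still gets its own variable.
    """
    return not isinstance(node, ast.Constant)


def plan(source):
    """For one comparison, report which operands would get a temporary."""
    compare = ast.parse(source, mode="eval").body
    operands = [compare.left] + compare.comparators
    return [needs_temporary(operand) for operand in operands]


print(plan("0 <= bar() < 3"))  # [False, True, False]
print(plan("1 == 1 < 3"))      # [False, False, False]
```

In visit_Compare, such a check could decide between appending an `int tN = ...;` line and using the literal text of the operand directly in the generated comparison.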

answered Dec 14 '22 01:12 by Stephan Schlecht