
Elegant way to test Python ASTs for equality (not reference or object identity)

Not sure of the terminology here, but this would be the difference between eq? and equal? in Scheme, or the difference between == and strncmp with C strings. In each case the first returns false for two distinct strings that have the same content, while the second reports them as equal.

I'm looking for the latter operation, for Python's ASTs.

Right now, I'm doing this:

import ast
def AST_eq(a, b):
    return ast.dump(a) == ast.dump(b)

which apparently works but feels like a disaster waiting to happen. Anyone know of a better way?

Edit: unfortunately, when I compare the two ASTs' __dict__ attributes, that comparison falls back to the individual elements' __eq__ methods. ASTs are implemented as trees of other AST nodes, and their __eq__ apparently checks for reference identity. So neither plain == nor the solution in Thomas's link works. (Besides, I don't want to subclass every AST node type just to insert a custom __eq__.)
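
To make the behaviour described above concrete, here is a minimal sketch for Python 3 using only the standard library. The variable names are illustrative assumptions, not part of the original question. It shows that == on two separately parsed trees is false, that the default ast.dump output omits line and column attributes (so it compares structure only), and that passing include_attributes=True makes the dump sensitive to source layout.

```python
import ast

source = "x = 1 + 2"
tree_a = ast.parse(source)
tree_b = ast.parse(source)

# AST nodes do not define a structural __eq__, so == compares identity.
identity_equal = tree_a == tree_b

# By default ast.dump leaves out lineno/col_offset, so two trees with the
# same structure and field values produce the same string.
dump_equal = ast.dump(tree_a) == ast.dump(tree_b)

# Same expression, different layout: the default dumps still match.
tree_c = ast.parse("x = (1 +\n     2)")
layout_insensitive = ast.dump(tree_a) == ast.dump(tree_c)

# With include_attributes=True the positions are part of the dump,
# so the differently laid out sources no longer compare equal.
position_sensitive = (
    ast.dump(tree_a, include_attributes=True)
    != ast.dump(tree_c, include_attributes=True)
)

print(identity_equal, dump_equal, layout_insensitive, position_sensitive)
```

So the string comparison is a structural check as long as include_attributes is left at its default; its main cost is building two potentially large strings for every comparison.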

Wang asked Jul 22 '10 20:07


2 Answers

I ran into the same problem. My first approach was to reduce the AST to a simpler representation (a tree of dicts):

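import ast

def simplify(node):
    # Python 2 code (dict.iteritems, list-returning map).
    if isinstance(node, ast.AST):
        res = vars(node).copy()
        # Drop position and context attributes; they are not structure.
        for k in 'lineno', 'col_offset', 'ctx':
            res.pop(k, None)
        for k, v in res.iteritems():
            res[k] = simplify(v)
        # Record the node type so different node classes never compare equal.
        res['__type__'] = type(node).__name__
        return res
    elif isinstance(node, list):
        return map(simplify, node)
    else:
        return node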

and then you can just compare these representations:

data = open("/usr/lib/python2.7/ast.py").read()
a1 = ast.parse(data)
a2 = ast.parse(data)
print simplify(a1) == simplify(a2)

This prints True.

EDIT

I just realized there's no need to build a dict at all; you can compare the nodes directly:

import ast
import itertools

def compare_ast(node1, node2):
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for k, v in vars(node1).iteritems():
            # Skip position and context attributes; they are not structure.
            if k in ('lineno', 'col_offset', 'ctx'):
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True
    elif isinstance(node1, list):
        # izip stops at the shorter list, so check the lengths first.
        if len(node1) != len(node2):
            return False
        return all(itertools.starmap(compare_ast, itertools.izip(node1, node2)))
    else:
        return node1 == node2

Yorik.sar answered Nov 20 '22 09:11


I modified @Yorik.sar's answer for Python 3.9+:

import ast
from itertools import zip_longest
from typing import Union


def compare_ast(
    node1: Union[ast.AST, list[ast.AST]],
    node2: Union[ast.AST, list[ast.AST]],
) -> bool:
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True

    elif isinstance(node1, list) and isinstance(node2, list):
        return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
    else:
        return node1 == node2
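
For anyone who wants to try the idea on concrete inputs, here is a self-contained sketch of the same recursive comparison. It is a variant under stated assumptions, not the code above verbatim: the name ast_equal and the IGNORED set are hypothetical, lists are compared with an explicit length check instead of zip_longest, and a missing attribute on the second node is treated as None.

```python
import ast

# Attributes that describe source position or context rather than structure.
IGNORED = {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}


def ast_equal(node1, node2):
    """Hypothetical helper: structural equality of two AST values."""
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for name, value in vars(node1).items():
            if name in IGNORED:
                continue
            # Assumption: an attribute missing on node2 is treated as None.
            if not ast_equal(value, getattr(node2, name, None)):
                return False
        return True
    if isinstance(node1, list):
        # Lists must have the same length and match element by element.
        return len(node1) == len(node2) and all(
            ast_equal(a, b) for a, b in zip(node1, node2)
        )
    # Leaf values (identifiers, constants, None) use ordinary equality.
    return node1 == node2


# Same call, different layout: structurally equal.
same = ast_equal(ast.parse("f(a, b)"), ast.parse("f(a,\n  b)"))
# Different argument name: not equal.
different = ast_equal(ast.parse("f(a, b)"), ast.parse("f(a, c)"))
# One module body is a prefix of the other: not equal.
prefix = ast_equal(ast.parse("a\nb"), ast.parse("a"))

print(same, different, prefix)
```

Like the answer above, this only walks the attributes present on the first node, so it is meant for trees produced by ast.parse rather than hand-built nodes with extra attributes.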

Seanny123 answered Nov 20 '22 11:11