Not sure of the terminology here, but this would be the difference between eq? and equal? in Scheme, or the difference between == and strncmp with C strings: in each case, the first would report two different strings that happen to have the same content as unequal, while the second would report them as equal.
I'm looking for the latter operation, for Python's ASTs.
Right now, I'm doing this:
import ast

def AST_eq(a, b):
    return ast.dump(a) == ast.dump(b)
This apparently works, but it feels like a disaster waiting to happen. Does anyone know of a better way?
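Part of why the ast.dump() approach holds up in practice: by default ast.dump() leaves out position attributes (lineno, col_offset and, on newer versions, end_lineno and end_col_offset) unless include_attributes=True is passed. Two parses that differ only in formatting therefore dump identically. A small illustration, assuming Python 3; the source strings are made up:

```python
import ast

# Two sources that differ only in whitespace, so their node positions differ.
tight = ast.parse("x = 1")
loose = ast.parse("x  =  1")
different = ast.parse("x = 2")

# By default ast.dump() omits position attributes,
# so the formatting difference does not show up.
print(ast.dump(tight) == ast.dump(loose))       # True
print(ast.dump(tight) == ast.dump(different))   # False

# With include_attributes=True the positions are dumped as well,
# and the whitespace difference makes the strings differ.
print(ast.dump(tight, include_attributes=True)
      == ast.dump(loose, include_attributes=True))  # False
```

The string comparison is still a proxy for structural equality, which is what makes it feel fragile: it relies on the repr of every leaf value being faithful.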
Edit: unfortunately, when I compare the two ASTs' __dict__s, that comparison falls back on the individual elements' __eq__ methods. ASTs are implemented as trees of other ASTs, and their __eq__ apparently checks for reference identity. So neither a straight == nor the solution in Thomas's link works. (Besides which, I also don't want to subclass every AST node type to insert this custom __eq__.)
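The identity behaviour described in the edit is easy to reproduce. A minimal sketch, assuming CPython 3 (where AST node classes do not define a structural __eq__):

```python
import ast

a = ast.parse("x + 1")
b = ast.parse("x + 1")

# Without a structural __eq__, == falls back to object identity.
print(a == b)  # False
print(a == a)  # True

# Comparing the attribute dicts doesn't help: the 'body' lists are compared
# element by element, and those elements are again nodes compared by identity.
print(vars(a) == vars(b))  # False
```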
I ran into the same problem. The approach I tried: first reduce the AST to a simpler representation (a tree of dicts):
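import ast

def simplify(node):
    if isinstance(node, ast.AST):
        res = vars(node).copy()
        # drop source positions (end_* only exist on Python 3.8+) and load/store context
        for k in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx'):
            res.pop(k, None)
        for k, v in res.items():
            res[k] = simplify(v)
        res['__type__'] = type(node).__name__
        return res
    elif isinstance(node, list):
        # a real list (not a lazy map object) so that == works on Python 3
        return [simplify(n) for n in node]
    else:
        return node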
and then you can just compare these representations:
with open("/usr/lib/python2.7/ast.py") as f:
    data = f.read()
a1 = ast.parse(data)
a2 = ast.parse(data)
print(simplify(a1) == simplify(a2))
This prints True.
EDIT
I just realized there's no need to build dicts at all, so you can simply do:
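import ast

def compare_ast(node1, node2):
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            # skip source positions (end_* exist on Python 3.8+) and load/store context
            if k in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx'):
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True
    elif isinstance(node1, list):
        # zip() stops at the shorter list, so the lengths must be checked first
        return len(node1) == len(node2) and all(
            compare_ast(n1, n2) for n1, n2 in zip(node1, node2))
    else:
        return node1 == node2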
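A variation on the same recursive idea (my own sketch, not the answer's code; ast_equal is a hypothetical name): instead of hard-coding which attributes to skip, walk only the node's _fields with ast.iter_fields(). Source positions live in _attributes rather than _fields, so they are skipped automatically, and position attributes added in later Python versions need no update. One behavioural difference: ctx is a field, so it is compared (Load() and Store() are different classes).

```python
import ast


def ast_equal(node1, node2):
    """Structurally compare two ASTs, ignoring source positions.

    Only node._fields are visited (via ast.iter_fields), so position
    attributes such as lineno/col_offset, kept in node._attributes,
    never take part in the comparison.
    """
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        # Same node class, so both nodes declare the same _fields.
        return all(
            ast_equal(value, getattr(node2, name, None))
            for name, value in ast.iter_fields(node1)
        )
    if isinstance(node1, list):
        # Check lengths explicitly; zip() alone would stop at the shorter list.
        return len(node1) == len(node2) and all(
            ast_equal(n1, n2) for n1, n2 in zip(node1, node2)
        )
    # Leaf values: identifiers, constants, None, ...
    return node1 == node2


print(ast_equal(ast.parse("f(a, b)"), ast.parse("f( a,b )")))  # True
print(ast_equal(ast.parse("f(a, b)"), ast.parse("f(a)")))      # False
```

Because the type check runs before the leaf comparison, constants such as 1 and True (or 1 and 1.0) are treated as different even though they compare equal with ==.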
I modified @Yorik.sar's answer for Python 3.9+:
import ast
from itertools import zip_longest
from typing import Any

# Recursive calls receive nodes, lists, strings, numbers and None,
# so the parameters are typed as Any rather than ast.expr.
def compare_ast(node1: Any, node2: Any) -> bool:
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            # skip source positions and load/store context
            if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True
    elif isinstance(node1, list) and isinstance(node2, list):
        # zip_longest pads the shorter list with None, so lists of
        # different lengths do not compare equal
        return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
    else:
        return node1 == node2
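On recent interpreters a hand-written walker may not be needed: Python 3.14 added ast.compare(), which recursively compares two ASTs and by default ignores position attributes. The sketch below is hedged accordingly: it calls ast.compare() only if the running interpreter provides it and otherwise falls back to the ast.dump() comparison from the question. asts_equal is a hypothetical name.

```python
import ast


def asts_equal(a, b):
    """Structural AST comparison that ignores source positions."""
    compare = getattr(ast, "compare", None)
    if compare is not None:
        # ast.compare() ignores attributes (positions) unless asked otherwise.
        return compare(a, b)
    # Older Pythons: ast.dump() also omits positions by default.
    return ast.dump(a) == ast.dump(b)


print(asts_equal(ast.parse("x = [1, 2]"), ast.parse("x=[1,2]")))    # True
print(asts_equal(ast.parse("x = [1, 2]"), ast.parse("x = [2, 1]")))  # False
```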