I'm analyzing the AST generated from Python code "for fun and profit", and I would like something more graphical than ast.dump to actually see the generated AST.
In theory it's already a tree, so it shouldn't be too hard to turn it into a graph, but I don't understand how to do it.
ast.walk seems to walk with a BFS strategy, and in the visit_X methods I can't see the parent, so I can't find a way to build a graph...
It seems like the only way is to write my own DFS walk function. Does that make sense?
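The DFS walk asked about here can be sketched as a small recursive generator that yields each node together with its parent. The name walk_with_parent is hypothetical. It relies only on ast.iter_child_nodes, which yields the direct AST children from both single-valued and list-valued fields. Unlike ast.walk, which uses a queue and therefore yields nodes breadth-first, this visits them depth-first in pre-order.

```python
import ast
from typing import Iterator, Optional, Tuple


def walk_with_parent(
    node: ast.AST, parent: Optional[ast.AST] = None
) -> Iterator[Tuple[Optional[ast.AST], ast.AST]]:
    """Depth-first (pre-order) walk yielding (parent, node) pairs; the root's parent is None."""
    yield parent, node
    for child in ast.iter_child_nodes(node):
        # recurse before moving on to the next sibling: this is what makes it DFS
        yield from walk_with_parent(child, node)


if __name__ == "__main__":
    tree = ast.parse("x = 1 + 2")
    for parent, node in walk_with_parent(tree):
        parent_name = type(parent).__name__ if parent is not None else None
        print(parent_name, "->", type(node).__name__)
```

Since it is recursive, an extremely deeply nested AST could hit Python's recursion limit; for ordinary source files that is unlikely.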
Fantastic, it works, and it's really simple:
import ast
from collections import defaultdict


class AstGraphGenerator(object):

    def __init__(self):
        # maps each parent node type to the list of its children's types
        self.graph = defaultdict(list)

    def __str__(self):
        return str(self.graph)

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        # record edges for list-valued fields too (e.g. Module.body),
                        # otherwise statements are never connected to their parent
                        self.graph[type(node)].append(type(item))
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.graph[type(node)].append(type(value))
                self.visit(value)
So it's the same as a normal NodeVisitor, but I have a defaultdict where, for each node, I append the type of each child. Then I pass this dictionary to pygraphviz.AGraph and get a nice result.
The only problem is that the type doesn't say much, while ast.dump() is far too verbose.
The best thing would be to get the actual source code for each node. Is that possible?
EDIT: now it's much better. I also pass the source code to the constructor and, where possible, label each node with its source line; otherwise I just print the type.
import ast
from collections import defaultdict


class AstGraphGenerator(object):

    def __init__(self, source):
        self.graph = defaultdict(list)
        self.source = source  # list of source lines, e.g. from text.splitlines()

    def __str__(self):
        return str(self.graph)

    def _getid(self, node):
        """Label a node with its type and, if it has a line number, that source line."""
        try:
            lineno = node.lineno - 1
            return "%s: %s" % (type(node).__name__, self.source[lineno].strip())
        except AttributeError:
            # nodes such as Module, Load or Add carry no position information
            return type(node).__name__

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        # record edges for list-valued fields too (e.g. Module.body)
                        self.graph[self._getid(node)].append(self._getid(item))
                        self.visit(item)
            elif isinstance(value, ast.AST):
                node_source = self._getid(node)
                value_source = self._getid(value)
                self.graph[node_source].append(value_source)
                # self.graph[type(node)].append(type(value))
                self.visit(value)
If you look at ast.NodeVisitor, it's a fairly trivial class. You can either subclass it or just reimplement its walking strategy to do whatever you need. For instance, keeping references to the parent as nodes are visited is very simple to implement this way: just add a visit method that also accepts the parent as an argument, and pass the parent from your own generic_visit.
P.S. By the way, NodeVisitor.generic_visit appears to implement DFS, so all you have to do is add the parent-node passing.
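A minimal sketch of that parent-passing approach, taking the "reimplement" route rather than subclassing. The class name ParentAwareVisitor and its edges attribute are hypothetical, and any visit_X methods you add are assumed to take (node, parent). Subclassing ast.NodeVisitor here is risky: in recent Python 3 versions NodeVisitor itself defines visit_Constant, which takes only the node, so the dispatch below would call it with an extra parent argument and fail.

```python
import ast
from typing import List, Optional, Tuple


class ParentAwareVisitor:
    """NodeVisitor-style walker whose visit methods also receive the parent node."""

    def __init__(self) -> None:
        # (parent type name, child type name) pairs, in DFS pre-order
        self.edges: List[Tuple[str, str]] = []

    def visit(self, node: ast.AST, parent: Optional[ast.AST] = None):
        # same dispatch as ast.NodeVisitor.visit, but forwards the parent
        method = "visit_" + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node, parent)

    def generic_visit(self, node: ast.AST, parent: Optional[ast.AST] = None) -> None:
        if parent is not None:
            self.edges.append((type(parent).__name__, type(node).__name__))
        # the current node becomes the parent of each of its children
        for child in ast.iter_child_nodes(node):
            self.visit(child, node)


if __name__ == "__main__":
    v = ParentAwareVisitor()
    v.visit(ast.parse("print(a + b)"))
    for edge in v.edges:
        print(edge)
```

Each recorded pair is one parent to child edge, so the list can be fed to the same graphing step as the defaultdict in the question.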