
Getting all the nodes from Python AST that correspond to a particular variable with a given name

Consider the code below:

1 | x = 20
2 | 
3 | def f():
4 |     x = 0
5 |     for x in range(10):
6 |         x += 10
7 |     return x
8 | f()
9 |
10| for x in range(10):
11|     pass
12| x += 1
13| print(x)

The value of x after execution of the code above is 10. Now, how can I get all the nodes with class Name whose ids are x and refer to the x that's being used in lines 1, 10, 12 and 13?

In other words, the x inside of f is different from the rest of the xs. Is it possible to get their AST nodes, having only the script and script's AST while not executing it?

Atonal asked Apr 02 '17 09:04


1 Answer

When walking the AST, track the context: start with a global context, then, as you encounter FunctionDef, ClassDef or Lambda nodes, push that context onto a stack (and pop it again when exiting the relevant node).

You can then look only at Name nodes in the global context. You can also track names declared with a global statement (I'd use a set per stack level), so that uses of those names inside a function are counted too.

Using a NodeVisitor subclass:

import ast

class GlobalUseCollector(ast.NodeVisitor):
    def __init__(self, name):
        self.name = name
        # stack of (context kind, set of names marked as `global` there)
        self.context = [('global', set())]

    def visit_FunctionDef(self, node):
        self.context.append(('function', set()))
        self.generic_visit(node)
        self.context.pop()

    # treat coroutines the same way
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        # a class body may contain `global` statements too
        self.context.append(('class', set()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Lambda(self, node):
        # lambdas are just functions, albeit with no statements, so no assignments
        self.context.append(('function', set()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Global(self, node):
        # `global` is also legal at module and class level, so don't assert
        # on the context kind; just record the names for the current scope
        self.context[-1][1].update(node.names)

    def visit_Name(self, node):
        ctx, g = self.context[-1]
        if node.id == self.name and (ctx == 'global' or node.id in g):
            print('{} used at line {}'.format(node.id, node.lineno))

Demo (given the AST tree for your sample code in t):

>>> GlobalUseCollector('x').visit(t)
x used at line 1
x used at line 10
x used at line 12
x used at line 13
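
For reference, t can be produced with ast.parse. The sketch below assumes the sample script is held in a string called source (a name chosen here for illustration). It also lists every Name node for x without any scope tracking, which shows why the context stack is needed: the lines inside f show up as well.

```python
import ast

# the sample script from the question, without the line-number gutter
source = """\
x = 20

def f():
    x = 0
    for x in range(10):
        x += 10
    return x
f()

for x in range(10):
    pass
x += 1
print(x)
"""

# parse only; nothing in the script is executed
t = ast.parse(source)

# naive approach: every Name node called x, regardless of scope
all_x_lines = sorted(
    n.lineno for n in ast.walk(t)
    if isinstance(n, ast.Name) and n.id == 'x'
)
print(all_x_lines)  # [1, 4, 5, 6, 7, 10, 12, 13]
```

Lines 4 to 7 belong to the local x inside f, and those are exactly the ones the visitor filters out.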

And using global x in a function:

>>> u = ast.parse('''\
... x = 20
...
... def g():
...     global x
...     x = 0
...     for x in range(10):
...         x += 10
...     return x
...
... g()
... for x in range(10):
...     pass
... x += 1
... print(x)
... ''')
>>> GlobalUseCollector('x').visit(u)
x used at line 1
x used at line 5
x used at line 6
x used at line 7
x used at line 8
x used at line 11
x used at line 13
x used at line 14
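
If you want the nodes themselves rather than printed line numbers, one possible variation is to append them to a list. This is only a sketch under the same simplifications as above: the class name GlobalNameCollector, the nodes attribute and the _visit_scope helper are hypothetical names introduced here, and a function that merely reads x without a global statement is not detected as using the global.

```python
import ast


class GlobalNameCollector(ast.NodeVisitor):
    """Collect Name nodes for `name` that refer to the module-level variable."""

    def __init__(self, name):
        self.name = name
        self.nodes = []  # matching ast.Name nodes, in visiting order
        # stack of (scope kind, set of names declared `global` in that scope)
        self.context = [('global', set())]

    def _visit_scope(self, node, kind):
        # push a new scope, visit the children, then pop the scope again
        self.context.append((kind, set()))
        self.generic_visit(node)
        self.context.pop()

    def visit_FunctionDef(self, node):
        self._visit_scope(node, 'function')

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        self._visit_scope(node, 'function')

    def visit_ClassDef(self, node):
        self._visit_scope(node, 'class')

    def visit_Global(self, node):
        # remember names declared global in the current scope
        self.context[-1][1].update(node.names)

    def visit_Name(self, node):
        kind, declared_global = self.context[-1]
        if node.id == self.name and (kind == 'global' or node.id in declared_global):
            self.nodes.append(node)


source = """\
x = 20

def f():
    x = 0
    for x in range(10):
        x += 10
    return x
f()

for x in range(10):
    pass
x += 1
print(x)
"""

collector = GlobalNameCollector('x')
collector.visit(ast.parse(source))
lines = [n.lineno for n in collector.nodes]
print(lines)  # [1, 10, 12, 13]
```

Each entry in collector.nodes is an ast.Name instance, so you can inspect its ctx (Store or Load), lineno and col_offset afterwards.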
Martijn Pieters answered Oct 19 '22 23:10