How do you use next_functions[0][0] on grad_fn correctly in pytorch?

Tags:

pytorch

I was given this network structure in the official PyTorch tutorial:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss

and then an example of how to follow the gradient graph backwards using the built-in .grad_fn attribute:

# Eg: 
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

So I thought I could reach the grad_fn object for the first Conv2d by chaining .next_functions[0][0] nine times, following the given examples, but I got a "tuple index out of range" error. How can I index these backprop objects correctly?
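For context, here is a minimal sketch (assuming only torch; the one-layer model is illustrative, not the tutorial's network) of why blindly chaining next_functions[0][0] fails: some entries in next_functions are AccumulateGrad leaf nodes whose own next_functions is an empty tuple, and branching means index [0] is not always the path back through the layers.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
out = layer(torch.randn(1, 3)).sum()

g = out.grad_fn                       # the sum's backward node
g = g.next_functions[0][0]            # the Linear op's backward node
# The node has several children: the bias leaf, the (non-grad) input,
# and the weight branch -- not a single linear chain.
print([type(fn).__name__ for fn, _ in g.next_functions])

# Following the AccumulateGrad leaf for the bias dead-ends:
leaf = [fn for fn, _ in g.next_functions
        if type(fn).__name__ == "AccumulateGrad"][0]
print(leaf.next_functions)            # empty tuple, so [0][0] raises IndexError
```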

asked Mar 06 '23 by Inkplay_

1 Answer

In the PyTorch CNN tutorial, after running the following code from the tutorial:

# net and input are defined earlier in the tutorial
output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

The following code snippet will print the full graph:

def print_graph(g, level=0):
    """Recursively print the autograd graph rooted at node g."""
    if g is None:
        return
    print('*' * level * 4, g)
    for subg in g.next_functions:
        print_graph(subg[0], level + 1)

print_graph(loss.grad_fn, 0)
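Rather than counting a fixed number of [0][0] hops, you can search the graph for the node you want. Below is a hedged sketch (using a tiny stand-in net, not the tutorial's exact architecture; backward node names like ConvolutionBackward0 vary across PyTorch versions, so it matches on a substring):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(1, 2, 3), nn.ReLU(), nn.Flatten(),
                    nn.Linear(2 * 6 * 6, 1))
loss = nn.MSELoss()(net(torch.randn(1, 1, 8, 8)), torch.randn(1, 1))

def find_node(g, name):
    """Depth-first search for the first backward node whose
    class name contains `name`."""
    if g is None:
        return None
    if name in type(g).__name__:
        return g
    for child, _ in g.next_functions:
        found = find_node(child, name)
        if found is not None:
            return found
    return None

conv_node = find_node(loss.grad_fn, "Conv")
print(type(conv_node).__name__)
```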

answered Apr 28 '23 by Randy Dueck