I was given this network structure in the official PyTorch tutorial:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss
followed by an example of how to walk the gradient graph backwards using the built-in .grad_fn attribute from Variable:
# e.g.:
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
So I thought I could reach the grad_fn object for Conv2d by chaining .next_functions[0][0] nine times, following the pattern in the example, but I got an IndexError: tuple index out of range. How can I index these backprop objects correctly?
After running the following setup from the PyTorch CNN tutorial:
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
The following code snippet will print the full backward graph:
def print_graph(g, level=0):
    if g is None:
        return
    # Indent each node by its depth in the graph.
    print('*' * level * 4, g)
    # Recurse into every parent node, not just entry [0].
    for subg in g.next_functions:
        print_graph(subg[0], level + 1)

print_graph(loss.grad_fn, 0)
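The output shows why blindly chaining [0][0] fails: the graph branches. A Linear layer's Addmm node, for instance, has several entries in next_functions (for the bias, the input, and the weight), and always taking entry [0] eventually lands on an AccumulateGrad leaf whose next_functions tuple is empty, which is what raises the tuple index out of range error. If you want the Conv2d node specifically, you can search the graph by class name instead of hard-coding indices. A minimal sketch (the helper name find_node is my own, and it assumes the net/loss built above):

def find_node(g, name):
    # Depth-first search for the first grad_fn node whose class
    # name contains `name` (e.g. 'Conv' matches ConvolutionBackward).
    if g is None:
        return None
    if name in type(g).__name__:
        return g
    for subg in g.next_functions:
        found = find_node(subg[0], name)
        if found is not None:
            return found
    return None

conv_node = find_node(loss.grad_fn, 'Conv')
print(conv_node)  # e.g. <ConvolutionBackward0 object at 0x...>

Note that the exact class names printed (MseLossBackward0, AddmmBackward0, ConvolutionBackward0, ...) vary across PyTorch versions, so matching on a substring like 'Conv' is more robust than comparing full names.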