 

How do I visualize a net in PyTorch?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate = 0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print(resnet)
make_dot(resnet)

I want to visualize resnet from the PyTorch models. How can I do it? I tried to use torchviz, but it gives an error:

'ResNet' object has no attribute 'grad_fn'
asked Sep 23 '18 by raaj




3 Answers

Here are three different graph visualizations using different tools.

To generate example visualizations, I'll use a simple RNN for sentiment analysis, taken from an online tutorial:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

Here is the output if you print() the model.

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

Below are the results from three different visualization tools.

For all of them, you need a dummy input that can pass through the model's forward() method. A simple way to get one is to retrieve a batch from your DataLoader, like this:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torchviz

https://github.com/szagoruyko/pytorchviz

I believe this tool generates its graph using the backward pass, so all the boxes use the PyTorch components for back-propagation.

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

This tool produces the following output file:

torchviz output

This is the only output that clearly mentions the three layers in my model, embedding, rnn, and fc. The operator names are taken from the backward pass, so some of them are difficult to understand.

HiddenLayer

https://github.com/waleedka/hiddenlayer

This tool uses the forward pass, I believe.

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

Here is the output. I like the shade of blue.

hiddenlayer output

I find that the output has too much detail and obfuscates my architecture. For example, why is unsqueeze mentioned so many times?
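If the extra detail bothers you, hiddenlayer ships more graph transforms than just Prune. A minimal sketch, assuming the fold-style transforms described in the project's README (names may differ in your installed version):

import hiddenlayer as hl

# Collapse noise before rendering (transform names follow the hiddenlayer README).
transforms = [
    hl.transforms.Prune('Constant'),   # remove Constant nodes, as above
    hl.transforms.FoldDuplicates(),    # merge runs of identical nodes (e.g. repeated unsqueeze)
]

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.save('rnn_hiddenlayer_folded', format='png')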

Netron

https://github.com/lutzroeder/netron

This tool is a desktop application for Mac, Windows, and Linux. It relies on the model being first exported into ONNX format. The application then reads the ONNX file and renders it. There is then an option to export the model to an image file.

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

Here's what the model looks like in the application. I think this tool is pretty slick: you can zoom and pan around, and you can drill into the layers and operators. The only negative I've found is that it only does vertical layouts.

Netron screenshot
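Since the question is about resnet50, the same export works there too. A minimal sketch, assuming a dummy 224x224 RGB input; the file it writes, resnet.onnx, can then be opened in Netron:

import torch
import torchvision.models as models

resnet = models.resnet50(pretrained=True)
resnet.eval()  # inference mode for a cleaner exported graph

dummy_input = torch.zeros(1, 3, 224, 224)  # batch of one ImageNet-sized image
torch.onnx.export(resnet, dummy_input, 'resnet.onnx',
                  input_names=['image'], output_names=['logits'])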

answered Oct 09 '22 by stackoverflowuser2010


make_dot expects a variable (i.e., a tensor with a grad_fn), not the model itself.

Try:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot the graph of a variable, not of an nn.Module
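
make_dot returns a graphviz Digraph, so (as in the first answer) you can also write it straight to a file:

make_dot(out).render("resnet_torchviz", format="png")  # writes resnet_torchviz.png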
answered Oct 09 '22 by Shai


You can have a look at PyTorchViz (https://github.com/szagoruyko/pytorchviz), "A small package to create visualizations of PyTorch execution graphs and traces."

Example PyTorchViz visualization
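
For the ResNet in the question, a minimal sketch along the same lines as the answers above: run a dummy input through the model and plot the output tensor, not the module.

import torch
import torchvision.models as models
from torchviz import make_dot

resnet = models.resnet50(pretrained=True)
x = torch.zeros(1, 3, 224, 224)  # dummy ImageNet-sized input
y = resnet(x)                    # forward pass builds the autograd graph
make_dot(y, params=dict(resnet.named_parameters())).render("resnet_pytorchviz", format="png")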

answered Oct 09 '22 by David J.