
How to get all the tensors in a graph?

I would like to access all the tensor instances of a graph. For example, I can check whether a tensor is detached, or I can check its size. This can be done in TensorFlow.

I don't want a visualization of the graph.

asked Dec 21 '18 by Tengerye

1 Answer

You can access the entire computation graph at runtime by using hooks. These are callback functions attached to nn.Module instances, which fire both on the forward pass and during backpropagation.

For the forward pass, you can register a callback with register_forward_hook. Similarly, for backpropagation you can use register_full_backward_hook.
Note: as of PyTorch 1.8.0, register_backward_hook has been deprecated in favour of register_full_backward_hook.

With these two functions, you basically have access to every tensor on the computation graph. It's entirely up to you whether you want to print all the tensors, print their shapes, or even insert breakpoints to investigate.

Here is a possible implementation:

def forward_hook(module, input, output):
    # `input` is a tuple holding the tensors passed to forward; `output` is usually a single tensor
    print([tuple(i.shape) for i in input], tuple(output.shape))

The input argument is passed by PyTorch as a tuple and contains all positional arguments passed to the forward function of the hooked module.

def backward_hook(module, grad_input, grad_output):
    # `grad_input` and `grad_output` are tuples of gradients w.r.t. the module's inputs and outputs
    print([g.shape for g in grad_input if g is not None], [g.shape for g in grad_output if g is not None])

For the backward hook, both grad_input and grad_output will be tuples and will have varying shapes depending on your model's layers.
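Tying this back to the question, here is a rough sketch of a forward hook (my own variant, with a hypothetical name) that checks whether each tensor is still attached to the autograd graph and what its size is:

import torch

def inspect_hook(module, input, output):
    for t in (*input, output):
        if torch.is_tensor(t):
            # a tensor with no grad_fn that does not require grad is detached from the graph
            attached = t.requires_grad or t.grad_fn is not None
            print(type(module).__name__, tuple(t.shape), 'attached' if attached else 'detached')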

Then you can hook these callbacks on any existing nn.Module. For example, you could loop over all child modules from your model:

for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_full_backward_hook(backward_hook)
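As a quick usage sketch (the input shape here is made up for illustration), a single forward and backward pass is enough to fire both kinds of hooks:

x = torch.randn(1, 100)   # hypothetical input shape
y = model(x)              # forward hooks fire here
y.sum().backward()        # full backward hooks fire here

Each register_* call also returns a handle, so you can keep it around and call handle.remove() once you no longer need the hook.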

To get the names of the modules, you can wrap the hook in a closure that captures the name and loop over your model's named_children (or named_modules, if you want to recurse into nested submodules):

def forward_hook(name):
    def hook(module, x, y):
        print(f'{name}: {[tuple(i.shape) for i in x]} -> {tuple(y.shape)}')
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

This could print the following during inference:

fc1: [(1, 100)] -> (1, 10)
fc2: [(1, 10)] -> (1, 5)
fc3: [(1, 5)] -> (1, 1)
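For reference, output like the above would come from a toy model shaped roughly like this (an assumption, just to make the example reproducible end to end):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 10)
        self.fc2 = nn.Linear(10, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

model = Net()
for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

model(torch.randn(1, 100))  # triggers the three named forward hooks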

As for the model's parameters, you can easily access the parameters of a given module in either hook by calling module.parameters(), which returns a generator.
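For instance, a sketch of a hook (hypothetical name) that also lists the hooked module's own parameters via named_parameters, which yields name/tensor pairs:

def param_hook(module, input, output):
    # module.parameters() / named_parameters() yield the module's own tensors (e.g. weight, bias)
    for pname, p in module.named_parameters():
        print(f'{pname}: {tuple(p.shape)}, requires_grad={p.requires_grad}')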

answered Sep 23 '22 by Ivan