I would like to access all the tensor instances of a graph, for example to check whether a tensor is detached or to check its size. This can be done in TensorFlow.
I don't want a visualization of the graph.
You can inspect the tensors of the computation graph at runtime. To do so, you can use hooks. These are functions plugged onto nn.Modules, both for inference and when backpropagating.
At inference, you can hook a callback function with register_forward_hook. Similarly for backpropagation, you can use register_full_backward_hook.
Note: as of PyTorch version 1.8.0, register_backward_hook has been deprecated.
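With these two functions, you have access to every tensor that enters or leaves a hooked module. It's entirely up to you whether you print all tensors, print their shapes, or even insert breakpoints to investigate. Intermediate tensors created by functional calls inside a module's forward are not passed to the hooks themselves; you see only each hooked module's inputs and outputs.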
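Since the question is about checking whether tensors are detached and what their sizes are, here is a minimal sketch of a forward hook that records, for each hooked module, the shape of every input and output, whether it requires grad, and whether it has a grad_fn. A tensor with requires_grad=False and no grad_fn is effectively detached from the autograd graph. The small model and the records list are illustrative assumptions, not part of the original answer.

```python
import torch
import torch.nn as nn

records = []  # (module type, "in"/"out", shape, requires_grad, has grad_fn)

def inspect_hook(module, inputs, output):
    name = type(module).__name__
    # inputs is a tuple of the positional arguments passed to forward
    for t in inputs:
        records.append((name, "in", tuple(t.shape), t.requires_grad, t.grad_fn is not None))
    # output is a single tensor for these layers
    records.append((name, "out", tuple(output.shape), output.requires_grad, output.grad_fn is not None))

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
for module in model:  # iterating a Sequential yields its layers
    module.register_forward_hook(inspect_hook)

x = torch.randn(2, 8)  # leaf tensor that does not require grad
model(x)
with_grad = list(records)
for r in with_grad:
    print(r)

records.clear()
with torch.no_grad():  # outputs computed here are not attached to the graph
    model(x)
for r in records:
    print(r)
```

In the first pass the Linear output requires grad and has a grad_fn (because the layer's weights require grad); under torch.no_grad() every recorded tensor reports requires_grad=False and no grad_fn.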
Here is a possible implementation:
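```python
def forward_hook(module, input, output):
    # input: tuple of the positional arguments passed to module.forward
    # output: whatever module.forward returned (often a single tensor)
    pass
```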
The input argument is passed by PyTorch as a tuple containing all positional arguments passed to the forward function of the hooked module. Keyword arguments are not included in it.
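To make the "positional arguments only" point concrete, here is a small sketch using a hypothetical module, AddScaled, that takes two tensors and an optional keyword argument. The hook only counts how many entries its input tuple has on each call. As far as I know, recent PyTorch releases (2.0 and later) also accept with_kwargs=True in register_forward_hook if you need the keyword arguments; that option is not used here.

```python
import torch
import torch.nn as nn

class AddScaled(nn.Module):
    # hypothetical module with two positional inputs and one keyword argument
    def forward(self, a, b, scale=1.0):
        return a + scale * b

n_inputs_seen = []  # size of the hook's input tuple on each call

def forward_hook(module, input, output):
    n_inputs_seen.append(len(input))

m = AddScaled()
m.register_forward_hook(forward_hook)

a, b = torch.ones(2, 3), torch.ones(2, 3)
m(a, b)             # both tensors passed positionally
m(a, b, scale=2.0)  # the keyword argument does not appear in the tuple
m(a, b=b)           # b passed by keyword, so it is missing from the tuple
print(n_inputs_seen)  # [2, 2, 1]
```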
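```python
def backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's positional inputs
    # grad_output: gradients w.r.t. the module's outputs
    pass
```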
For the backward hook, both grad_input and grad_output are tuples whose shapes vary depending on your model's layers. Entries can be None, for instance for non-tensor arguments.
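A minimal sketch of what those tuples look like in practice, assuming a small hypothetical Sequential model and an input that requires grad. The hook records the shape of each gradient, guarding against None entries. Note that grad_input holds gradients with respect to the module's inputs, not its weights.

```python
import torch
import torch.nn as nn

grad_shapes = []  # (layer name, grad_input shapes, grad_output shapes)

def shapes(grads):
    # entries may be None, so guard before reading .shape
    return [None if g is None else tuple(g.shape) for g in grads]

def backward_hook(name):
    def hook(module, grad_input, grad_output):
        grad_shapes.append((name, shapes(grad_input), shapes(grad_output)))
    return hook

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
for name, module in model.named_children():  # names are "0", "1", "2"
    module.register_full_backward_hook(backward_hook(name))

x = torch.randn(2, 8, requires_grad=True)
model(x).sum().backward()
for entry in grad_shapes:
    print(entry)
```

The hooks fire in reverse layer order during backpropagation, so the last layer's entry is printed first.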
Then you can hook these callbacks onto any existing nn.Module. For example, you could loop over all child modules of your model:
```python
for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_full_backward_hook(backward_hook)
```
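One thing to keep in mind: model.children() only yields the direct children, so layers inside a nested block are not hooked individually. The sketch below (hypothetical model) compares children() with modules(), which recurses and also yields the root model. It also uses the handle returned by register_forward_hook to remove hooks again.

```python
import torch
import torch.nn as nn

fired = []  # type names of the modules whose hooks ran, in order

def forward_hook(module, inputs, output):
    fired.append(type(module).__name__)

model = nn.Sequential(
    nn.Linear(6, 4),
    nn.Sequential(nn.ReLU(), nn.Linear(4, 2)),  # nested block
)

# children(): only the two top-level modules get a hook
handles = [m.register_forward_hook(forward_hook) for m in model.children()]
model(torch.randn(1, 6))
first_pass = list(fired)
print(first_pass)  # ['Linear', 'Sequential']

# each registration returns a handle; remove() detaches that hook
for h in handles:
    h.remove()
fired.clear()

# modules(): recurses into the nested block and includes the root model
handles = [m.register_forward_hook(forward_hook) for m in model.modules()]
model(torch.randn(1, 6))
print(fired)  # ['Linear', 'ReLU', 'Linear', 'Sequential', 'Sequential']
```

A forward hook runs after its module's forward returns, which is why the containers appear after the layers they contain.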
To get the names of the modules, you can wrap the hook in a function that encloses the name, and loop over your model's named_children (or named_modules, to include nested submodules):
```python
def forward_hook(name):
    # returns a hook that remembers the module's name via a closure
    def hook(module, x, y):
        print(f'{name}: {[tuple(i.shape) for i in x]} -> {tuple(y.shape)}')
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))
```
This could print the following at inference:
```
fc1: [(1, 100)] -> (1, 10)
fc2: [(1, 10)] -> (1, 5)
fc3: [(1, 5)] -> (1, 1)
```
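For completeness, here is a self-contained sketch that produces exactly that output. The Net class, its layer sizes, and the ReLU activations are assumptions chosen to match the printed shapes; the original answer does not show the model. The hook is the named version above, additionally collecting its lines in a list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # hypothetical model whose layer names and sizes match the output above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 10)
        self.fc2 = nn.Linear(10, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))  # functional call: not visible to module hooks
        x = F.relu(self.fc2(x))
        return self.fc3(x)

lines = []

def forward_hook(name):
    def hook(module, x, y):
        line = f'{name}: {[tuple(i.shape) for i in x]} -> {tuple(y.shape)}'
        lines.append(line)
        print(line)
    return hook

model = Net()
for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

with torch.no_grad():
    model(torch.randn(1, 100))
```

Because the activations are applied with F.relu rather than nn.ReLU modules, they produce no lines of their own.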
As for the model's parameters, you can access the parameters of a given module in both hooks by calling module.parameters(). This returns a generator.
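A short sketch of reading parameters from inside a forward hook, on a hypothetical Sequential model. It uses named_parameters(recurse=False) to list only each module's own parameters, and consumes the parameters() generator once to count elements. If you need to iterate more than once, wrap it in list() first.

```python
import torch
import torch.nn as nn

param_info = {}    # layer name -> {parameter name: shape}
param_counts = {}  # layer name -> total number of parameter elements

def forward_hook(name):
    def hook(module, inputs, output):
        # recurse=False: only this module's own parameters
        param_info[name] = {p_name: tuple(p.shape)
                            for p_name, p in module.named_parameters(recurse=False)}
        # parameters() is a generator, consumed once here
        param_counts[name] = sum(p.numel() for p in module.parameters())
    return hook

model = nn.Sequential(nn.Linear(100, 10), nn.ReLU(), nn.Linear(10, 1))
for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

model(torch.randn(1, 100))
print(param_info)
print(param_counts)  # {'0': 1010, '1': 0, '2': 11}
```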