
How to get the output from a specific layer from a PyTorch model?




How to extract the features from a specific layer from a pre-trained PyTorch model (such as ResNet or VGG), without doing a forward pass again?

bryant1410 asked Oct 13 '18 at 18:10



1 Answer

New answer

Edit: torchvision v0.11.0 added a feature-extraction utility, create_feature_extractor in torchvision.models.feature_extraction, that returns intermediate outputs directly.

For example, if you want to extract the features from the layer layer4.2.relu_2, you can do:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)

model = resnet50()

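# Maps node names in the traced graph to the keys of the returned dict
return_nodes = {
    "layer4.2.relu_2": "layer4",
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)  # {"layer4": tensor of shape (1, 2048, 7, 7)}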

Old answer

You can register a forward hook on the specific layer you want. Something like:

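def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

# model, some_specific_layer and some_input are placeholders for your own model, layer and input
model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)  # the hook is called during this forward pass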
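For example, to obtain the res5c output in ResNet (the output of layer4 in torchvision's implementation), you can store it in a global variable (or use nonlocal instead if the code lives inside a function):

import torch
from torchvision.models import resnet50

resnet = resnet50()
res5c_output = None  # set by the hook during the forward pass

def res5c_hook(module, input_, output):
    global res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)
resnet(torch.rand(1, 3, 224, 224))  # the hook fires during this forward pass

# Then, use `res5c_output`.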
bryant1410 answered Sep 20 '22 at 10:09
