 

PyTorch get all layers of model

Tags:

python

pytorch

What's the easiest way to take a PyTorch model and get a list of all its layers without any nn.Sequential groupings? For example, is there a better way to do this?

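import torch.nn as nn
import pretrainedmodels

def unwrap_model(model, layers):
    # children() yields only the direct submodules, so recurse into nested
    # nn.Sequential containers and collect every other module.
    for child in model.children():
        if isinstance(child, nn.Sequential):
            unwrap_model(child, layers)
        else:
            layers.append(child)

model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
unwrap_model(model, l)

print(l)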
asked Feb 23 '19 at 22:02 by Austin

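A minimal sketch of the same idea that avoids the global list and does not single out nn.Sequential: a recursive generator that descends into every submodule (Sequential, ModuleList, custom blocks) and yields only leaf modules, i.e. modules with no children. The name `leaf_modules` is hypothetical, and the small toy model stands in for the pretrainedmodels Xception.

```python
import torch.nn as nn


def leaf_modules(module: nn.Module):
    """Hypothetical helper: recursively yield modules that have no child modules."""
    children = list(module.children())
    if not children:
        # No submodules: treat this module as a "layer".
        yield module
        return
    for child in children:
        # Descend into any container type, not only nn.Sequential.
        yield from leaf_modules(child)


# Toy model used only for illustration (it is never run forward).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
    nn.ModuleList([nn.Linear(4, 4)]),
)
layers = list(leaf_modules(model))
print(layers)
```

Note that a module which owns parameters directly and also has children would be skipped under this definition of "leaf"; adjust the condition if your model contains such modules.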


2 Answers

You can iterate over all modules of a model (including those nested inside each Sequential) with the modules() method, which walks the module tree recursively and also yields the model itself. Here's a simple example:

>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))
>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]
>>> l
[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]
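Filtering on nn.Sequential alone keeps other containers: with an nn.ModuleList (or a custom block class) in the tree, modules() yields the container itself and the list comprehension above would include it. A hedged variant, assuming you want only leaf layers together with their dotted names, uses named_modules() and keeps modules that have no children:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.ModuleList([nn.Linear(2, 1), nn.Sigmoid()]),
)

# named_modules() walks the whole tree, including the root (whose name is "").
# Keeping only modules without children drops every container, whatever its type.
leaves = [
    (name, module)
    for name, module in model.named_modules()
    if next(module.children(), None) is None
]

for name, module in leaves:
    print(name, module)
```

The names ("0", "1", "2.0", "2.1") are the same keys that appear in model.state_dict(), which makes them handy for freezing or inspecting specific layers.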
answered Sep 21 '22 at 09:09 by Andreas K.



If you want a nested dictionary with names as keys and modules as values, e.g.:

{'conv1': Conv2d(...),
 'bn1': BatchNorm2d(...),
 'block1':{
    'group1':{
        'conv1': Conv2d(...),
        'bn1': BatchNorm2d(...),
        'conv2': Conv2d(...),
        'bn2': BatchNorm2d(...),
    },
    'group2':{ ...
    }, ...
 }
}

You can combine the answers of Kees and Mayukh Deb to get:

import torch

def nested_children(m: torch.nn.Module):
    children = dict(m.named_children())
    if not children:
        # No children: m is a leaf module, so return it directly.
        return m
    # Otherwise recurse into every child until the leaves are reached.
    output = {}
    for name, child in children.items():
        output[name] = nested_children(child)
    return output
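A usage sketch on a small model, with a hypothetical helper `flatten_nested` that turns the nested dictionary back into a flat {dotted_name: module} mapping. The function from this answer is repeated so the block runs on its own; the toy model is an assumption for illustration. The flattened keys match the leaf names that named_modules() reports.

```python
import torch.nn as nn


def nested_children(m: nn.Module):
    # Same logic as the function above, repeated so this block is standalone.
    children = dict(m.named_children())
    if not children:
        return m
    return {name: nested_children(child) for name, child in children.items()}


def flatten_nested(tree, prefix=""):
    """Hypothetical helper: turn the nested dict into {"a.b.c": module}."""
    if isinstance(tree, nn.Module):
        return {prefix: tree}
    flat = {}
    for name, subtree in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        flat.update(flatten_nested(subtree, key))
    return flat


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
)
tree = nested_children(model)
print(tree)
flat = flatten_nested(tree)
print(list(flat))  # ['0', '1.0', '1.1']
```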
answered Sep 17 '22 at 09:09 by Jetze Schuurmans
