What's the easiest way to take a PyTorch model and get a list of all its layers without any nn.Sequential
groupings? For example, is there a better way to do this?
import torch.nn as nn
import pretrainedmodels

def unwrap_model(model, layers):
    # Recursively flatten nn.Sequential containers into a single list.
    for child in model.children():
        if isinstance(child, nn.Sequential):
            unwrap_model(child, layers)
        else:
            layers.append(child)

model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
unwrap_model(model, l)
print(l)
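You can iterate over all modules of a model (including those inside each Sequential) with the modules()
method. It also yields the model itself, which is why the filter below drops every nn.Sequential, the outer one included.
Here's a simple example: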
>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))
>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]
>>> l
[Linear(in_features=2, out_features=2, bias=True),
ReLU(),
Linear(in_features=2, out_features=1, bias=True),
Sigmoid()]
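One caveat with the isinstance filter: modules() walks the entire tree, so any container that is not an nn.Sequential (for example a custom nn.Module block that groups a conv and an activation) stays in the list next to the layers inside it. A minimal sketch of an alternative, assuming that "layer" means a module with no children of its own; leaf_modules and the Block class are hypothetical names used only for illustration:

```python
import torch.nn as nn


def leaf_modules(model: nn.Module) -> list:
    """Return every module in the tree that has no child modules."""
    # modules() visits the model itself and all nested submodules; keeping
    # only childless ones drops every container, Sequential or custom.
    return [m for m in model.modules() if next(m.children(), None) is None]


class Block(nn.Module):
    # Hypothetical custom container, standing in for a model's own blocks.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


model = nn.Sequential(Block(), nn.Sequential(nn.Flatten(), nn.Linear(8, 1)))

# The isinstance filter keeps the custom Block container as well:
print([type(m).__name__ for m in model.modules()
       if not isinstance(m, nn.Sequential)])
# ['Block', 'Conv2d', 'ReLU', 'Flatten', 'Linear']

# The leaf filter keeps only the actual layers:
print([type(m).__name__ for m in leaf_modules(model)])
# ['Conv2d', 'ReLU', 'Flatten', 'Linear']
```

The trade-off is that a module which holds its own parameters and also has submodules would be skipped by the leaf filter, so it is worth checking the result against your architecture.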
If you want a nested dictionary with names as keys and modules as values, e.g.:
{'conv1': Conv2d(...),
 'bn1': BatchNorm2d(...),
 'block1': {
     'group1': {
         'conv1': Conv2d(...),
         'bn1': BatchNorm2d(...),
         'conv2': Conv2d(...),
         'bn2': BatchNorm2d(...),
     },
     'group2': { ...
     }, ...
 }
}
You can combine the answers of Kees and Mayukh Deb to get:
import torch

def nested_children(m: torch.nn.Module):
    children = dict(m.named_children())
    if not children:
        # No children: m is a leaf module, so return it as-is.
        return m
    # Otherwise recurse into each child and nest the results under its name.
    return {name: nested_children(child) for name, child in children.items()}
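If a flat mapping is enough, named_modules() already yields (dotted_name, module) pairs for the whole tree, with the root first under the empty name "". A small sketch, assuming you only want the leaf layers; leaf_modules_by_name is a hypothetical helper name:

```python
import torch.nn as nn


def leaf_modules_by_name(model: nn.Module) -> dict:
    # Keys are dotted paths such as "2.0"; containers (including the root,
    # whose name is "") are skipped because they have children.
    return {
        name: m
        for name, m in model.named_modules()
        if next(m.children(), None) is None
    }


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.BatchNorm2d(8)),
)

for name, m in leaf_modules_by_name(model).items():
    print(name, type(m).__name__)
# 0 Conv2d
# 1 BatchNorm2d
# 2.0 Conv2d
# 2.1 BatchNorm2d
```

Sequential names its children "0", "1", ..., so the dotted keys mirror the nesting that nested_children turns into sub-dictionaries.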