When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the module in nn.Sequential()
first? E.g.:
class My_Model_1(nn.Module):
    def __init__(self, D_in, D_out):
        super(My_Model_1, self).__init__()
        self.layer = nn.Linear(D_in, D_out)

    def forward(self, x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self, D_in, D_out):
        super(My_Model_2, self).__init__()
        self.layer = nn.Sequential(nn.Linear(D_in, D_out))

    def forward(self, x):
        out = self.layer(x)
        return out
model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)
How do I print the weights now?
model_2.layer.0.weight
doesn't work.
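The attribute-style access fails because 0 is not a valid Python identifier, but nn.Sequential supports integer indexing, so bracket syntax works. A minimal sketch, reusing the shapes from the example above:

```python
import torch.nn as nn

# nn.Sequential stores its submodules by index, so use bracket indexing
seq = nn.Sequential(nn.Linear(10, 10))
print(seq[0].weight)        # the Linear layer's weight matrix
print(seq[0].weight.shape)  # (out_features, in_features) = (10, 10)
```

Applied to the question, that is `model_2.layer[0].weight`.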
Secondly, nn.Sequential runs its layers one after another: it takes the input, runs layer1, feeds output1 into layer2, and feeds output2 into layer3, producing output3 as the result. So nn.Sequential is a construction used when you want to run certain layers sequentially.
nn.Sequential is a convenience rather than a performance optimization; the layers themselves run just as they would in a hand-written forward.
The objective of nn.Sequential is to quickly implement sequential modules so that you are not required to write the forward definition, since it is implicitly known: the layers are called sequentially, each on the previous output. In a more complicated module, though, you might need multiple sequential submodules.
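The three-layer chaining described above can be sketched as follows (layer sizes are made up for illustration):

```python
import torch
import torch.nn as nn

# A three-step chain: the output of each module feeds the next
net = nn.Sequential(
    nn.Linear(10, 20),  # layer1
    nn.ReLU(),          # layer2
    nn.Linear(20, 5),   # layer3
)

x = torch.randn(4, 10)
y = net(x)  # same as applying the three modules in order by hand
print(y.shape)  # torch.Size([4, 5])
```

Calling `net(x)` is exactly equivalent to a forward method that applies the three modules in order, which is the boilerplate nn.Sequential saves you from writing.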
PyTorch doesn't have a utility function (at least at the moment!) to count the number of model parameters, but there is a method of the model class you can use to get them: model.parameters(). PyTorch modules have a method called parameters() which returns an iterator over all the parameters.
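So counting parameters is a one-liner over that iterator; each parameter tensor reports its element count via numel(). A small sketch:

```python
import torch.nn as nn

model = nn.Linear(10, 10)

# Sum the element counts of every parameter tensor in the model
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 110: a 10x10 weight matrix plus a 10-element bias
```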
An easy way to access the weights is to use the state_dict()
of your model.
This should work in your case:
for k, v in model_2.state_dict().items():
    print("Layer {}".format(k))
    print(v)
Another option is to get the modules()
iterator. If you know beforehand the type of your layers this should also work:
for layer in model_2.modules():
    if isinstance(layer, nn.Linear):
        print(layer.weight)
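A related option is named_parameters(), which pairs each parameter with its qualified name, so no isinstance check is needed. A minimal sketch with a stand-in Sequential model:

```python
import torch.nn as nn

model_2 = nn.Sequential(nn.Linear(10, 10))

# Names encode the submodule path inside the Sequential,
# e.g. "0.weight" and "0.bias" for the first submodule
for name, param in model_2.named_parameters():
    print(name, param.shape)
```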