PyTorch: access weights of a specific module in nn.Sequential()

Tags: python, pytorch

When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the module in nn.Sequential() first? E.g.:

class My_Model_1(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_1, self).__init__()
        self.layer = nn.Linear(D_in,D_out)
    def forward(self,x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_2, self).__init__()
        self.layer = nn.Sequential(nn.Linear(D_in,D_out))
    def forward(self,x):
        out = self.layer(x)
        return out

model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)

How do I print the weights now? model_2.layer.0.weight doesn't work; it isn't even valid Python syntax, since an attribute name can't start with a digit.

asked May 31 '17 by mbpaulus


1 Answer

An easy way to access the weights is to use the state_dict() of your model.

This should work in your case:

for k, v in model_2.state_dict().items():
    print("Layer {}".format(k))
    print(v)
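
Note that the parameter names reflect each submodule's position inside the Sequential container: for model_2 this loop prints entries named layer.0.weight and layer.0.bias.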

Another option is to iterate over modules(). If you know the type of your layers beforehand, this should also work:

for layer in model_2.modules():
    if isinstance(layer, nn.Linear):
        print(layer.weight)
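
Finally, nn.Sequential supports integer indexing, so the specific layer from the question can also be reached directly; subscripting works where attribute-style access fails:

# index into the Sequential container instead of using attribute access
print(model_2.layer[0].weight)
print(model_2.layer[0].bias)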
answered Sep 25 '22 by aesadde