 

Reset parameters of a neural network in pytorch

I have a neural network with the following structure:

import torch
import torch.nn as nn

class myNetwork(nn.Module):
    def __init__(self):
        super(myNetwork, self).__init__()
        # Bidirectional GRU: 2 input features -> 2 * 100 = 200 output features
        self.bigru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(200, 32)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(32, 2)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

I need to restore the model to an untrained state by resetting the parameters of the neural network. For the nn.Linear layers I can do so with the method below:

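def reset_weights(self):  # method of myNetwork
    # Re-draws only the Linear weights; the biases and the GRU keep their trained values
    torch.nn.init.xavier_uniform_(self.fc1.weight)
    torch.nn.init.xavier_uniform_(self.fc2.weight)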

However, I could not find a similar snippet for resetting the weights of the nn.GRU layer.

My question is: how does one reset the nn.GRU layer? Any other way of resetting the whole network is also fine. Any help is appreciated.
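Following the same pattern as reset_weights above, the GRU's weights can also be re-drawn by hand, because nn.GRU exposes them as ordinary parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, plus *_reverse copies when bidirectional=True). The following is a minimal sketch under stated assumptions: the helper name reset_gru_xavier is hypothetical, and Xavier-uniform weights with zero biases is an assumed choice that differs from nn.GRU's built-in initialization.

```python
import torch
import torch.nn as nn


def reset_gru_xavier(gru: nn.GRU) -> None:
    """Hypothetical helper: re-initialize a GRU in the same Xavier style as fc1/fc2.

    Assumption: Xavier-uniform for weight matrices, zeros for biases.
    This is NOT the scheme nn.GRU uses by default.
    """
    for name, param in gru.named_parameters():
        if name.startswith("weight"):
            # weight_ih_l0, weight_hh_l0 and their *_reverse twins
            nn.init.xavier_uniform_(param)
        elif name.startswith("bias"):
            nn.init.zeros_(param)


torch.manual_seed(0)
gru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
reset_gru_xavier(gru)
```

Note that Xavier is applied to each stacked matrix as a whole: weight_ih_l0 has shape (3 * hidden_size, input_size) and holds all three gates, so the fan-out used is 300, not 100 per gate.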

learner asked Aug 28 '20 05:08


1 Answer

You can call the reset_parameters() method of each layer, as given here:

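# Re-runs each direct submodule's default init (bigru, fc1, fc2 here); use
# model.modules() instead if some layers are nested inside containers.
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()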
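A self-contained sketch of this approach applied to the network from the question, under a few assumptions: the forward method is hypothetical (the question does not show one), and the loop uses modules() rather than children(). Because reset_parameters() restores PyTorch's default initialization (for nn.Linear that is not Xavier), the sketch re-applies xavier_uniform_ afterwards so that a reset model matches a freshly constructed one.

```python
import torch
import torch.nn as nn


class myNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.bigru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(200, 32)
        self.fc2 = nn.Linear(32, 2)
        self.reset_weights()

    def reset_weights(self):
        # modules() also visits nested submodules, unlike children()
        for layer in self.modules():
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()  # GRU: uniform(-1/sqrt(100), 1/sqrt(100))
        # Re-apply the Xavier init from the question so a reset matches __init__
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        # Hypothetical forward pass: classify from the last time step
        out, _ = self.bigru(x)  # (batch, seq, 200)
        return self.fc2(torch.relu(self.fc1(out[:, -1])))


# Usage: one training step, then reset
torch.manual_seed(0)
model = myNetwork()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 5, 2), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()

trained = model.bigru.weight_hh_l0.detach().clone()
model.reset_weights()
```

The parameters are re-initialized in place, so an existing optimizer still points at them. Its internal state (for example Adam's moment estimates) is not reset, though, so for a truly fresh start it is safer to create a new optimizer as well.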

Another way is to save the model's state before training and reload it later, using torch.save and torch.load. See the docs, or Saving and Loading Models, for more.
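A minimal sketch of the save-and-reload approach, assuming a stand-in model built from the same layers as myNetwork (an nn.ModuleDict is used only to keep the example short). The snapshot is written to an in-memory buffer; the file name "init.pt" mentioned in the comment is hypothetical.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in with the same layers as myNetwork, assumed for illustration
model = nn.ModuleDict({
    "bigru": nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True),
    "fc1": nn.Linear(200, 32),
    "fc2": nn.Linear(32, 2),
})

# 1. Snapshot the untrained parameters once, before any training.
#    An in-memory buffer is used here; a file path such as "init.pt" works too.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# 2. Training changes the parameters (simulated here by shifting every value).
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# 3. Restore the snapshot whenever an untrained model is needed again.
buffer.seek(0)
model.load_state_dict(torch.load(buffer))
```

Unlike reset_parameters(), this brings back exactly the same initial values (including any custom Xavier init), which is convenient when several runs should start from an identical point. To keep the snapshot in memory without torch.save, copy.deepcopy(model.state_dict()) also works; the deep copy matters because state_dict() returns tensors that share storage with the live parameters.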

Dishin H Goyani answered Oct 07 '22 12:10