
Custom weight initialization in PyTorch

Tags: python, pytorch

What would be the right way to implement a custom weight initialization method in PyTorch?

I believe I can't directly add a method to `torch.nn.init`, but I wish to initialize my model's weights with my own proprietary method.

Omry Sendik asked Jul 01 '18 18:07


People also ask

What is weight initialization in PyTorch?

The aim of weight initialization is to keep layer activations from exploding or vanishing during the forward pass through a deep neural network. If either happens, the loss gradients become too large or too small to flow backward usefully, and the network takes longer to converge, if it converges at all.
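The effect can be seen in a small experiment. The sketch below is only an illustration under simplifying assumptions (a stack of bias-free linear maps with no activation function; the helper name `final_activation_std`, the width of 256, and the depth of 20 are made up for this example). It pushes a random vector through the stack with three different weight scales and reports the standard deviation of the final activations.

```python
import math
import torch

def final_activation_std(weight_std, width=256, depth=20, seed=0):
    """Push one random vector through `depth` bias-free linear maps and
    return the standard deviation of the final activations."""
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(width, generator=gen)
    for _ in range(depth):
        # Fresh Gaussian weight matrix with the requested standard deviation.
        w = torch.randn(width, width, generator=gen) * weight_std
        x = w @ x
    return x.std().item()

width = 256
# Each layer scales the activations by roughly weight_std * sqrt(width).
too_large = final_activation_std(1.0, width)                  # factor ~16 per layer
too_small = final_activation_std(0.01, width)                 # factor ~0.16 per layer
scaled = final_activation_std(1.0 / math.sqrt(width), width)  # factor ~1 per layer
print(too_large, too_small, scaled)
```

Only the weights scaled by 1/sqrt(width) keep the activations at a usable magnitude; the other two settings blow up or collapse after a few layers.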

How are weights initialized by default PyTorch?

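By default, PyTorch initializes the weights of layers such as nn.Linear and nn.Conv2d from a uniform distribution scaled by the layer's fan-in, which is often described as LeCun-style initialization. Nothing extra has to be done to get this default, whereas Normal, Xavier, or Kaiming initialization has to be applied explicitly through torch.nn.init.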
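One way to check what the default produces is to inspect a freshly constructed layer. The sketch below assumes a recent PyTorch release, where, as far as I know, nn.Linear draws its weights uniformly within ±1/sqrt(fan_in); the layer sizes are arbitrary. It also shows a built-in scheme from torch.nn.init replacing the default in place.

```python
import math
import torch
import torch.nn as nn

layer = nn.Linear(in_features=100, out_features=10)
fan_in = layer.in_features
bound = 1.0 / math.sqrt(fan_in)  # 0.1 for fan_in = 100

# Largest absolute default weight; assumed to stay within the bound.
max_abs_weight = layer.weight.detach().abs().max().item()
print(max_abs_weight, bound)

# Any built-in scheme can overwrite the default in place.
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
```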

What are the methods of initialization of weights?

Step 1: Initialization of the neural network: initialize the weights and biases. Step 2: Forward propagation: using the given input X, weights W, and biases b, compute for every layer a linear combination of inputs and weights (Z), then apply an activation function to that linear combination (A).
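The two steps can be sketched for a single layer as follows. The shapes, the 0.1 weight scale, and the choice of ReLU as the activation are assumptions made for illustration, not part of the text above.

```python
import torch

torch.manual_seed(0)
n_samples, n_in, n_out = 4, 3, 2

# Step 1: initialize weights and biases (small random weights, zero biases).
W = torch.randn(n_in, n_out) * 0.1
b = torch.zeros(n_out)

# Step 2: forward propagation for one layer.
X = torch.randn(n_samples, n_in)
Z = X @ W + b      # linear combination of inputs and weights
A = torch.relu(Z)  # activation function applied to the linear combination
print(Z.shape, A.shape)
```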


2 Answers

You can define a function that initializes the weights according to each layer's type:

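import torch.nn as nn

def weights_init(m):
    # model.apply() calls this once for every submodule.
    classname = m.__class__.__name__

    if 'Conv2d' in classname:
        # Convolution weights: Gaussian with mean 0.0 and std 0.02.
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm scale: Gaussian around 1.0; shift: zero.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)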

And then just apply it to your network:

model = create_your_model()
model.apply(weights_init)
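The same pattern works with an initializer of your own, since the functions in torch.nn.init are ordinary functions that modify a tensor in place. Below is a minimal sketch; `my_custom_init_`, its `scale` parameter, and the small example model are hypothetical stand-ins for your proprietary method and your network.

```python
import math
import torch
import torch.nn as nn

def my_custom_init_(tensor, scale=0.5):
    """Hypothetical custom scheme: uniform in +/- scale / sqrt(fan_in)."""
    with torch.no_grad():  # keep the in-place write out of autograd
        fan_in = tensor[0].numel()  # number of inputs feeding each output unit
        bound = scale / math.sqrt(fan_in)
        return tensor.uniform_(-bound, bound)

def weights_init(m):
    # Only touch layers that own a weight matrix or kernel.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        my_custom_init_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Example model, assumed only for this sketch (expects 3x8x8 inputs).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)
model.apply(weights_init)  # visits every submodule recursively
```

Checking the layer with isinstance instead of matching on the class name avoids accidental matches on unrelated classes, at the cost of listing each layer type explicitly.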
Manuel Lagunas answered Oct 06 '22 05:10


See https://discuss.pytorch.org/t/how-to-initialize-weights-bias-of-rnn-lstm-gru/2879/2 for reference.

You can also overwrite entries of the model's state dict and load it back:

weight_dict = net.state_dict()
new_weight_dict = {}
for param_key in weight_dict:
    # Custom initialization goes here: store a new tensor in
    # new_weight_dict[param_key]. You can initialize partially, i.e. only
    # some of the entries, and let the others stay as they are.
    pass
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
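As a concrete, runnable instance of this approach, the sketch below re-initializes only the bias entries and leaves the weights as they are. The small `net` and the rule of setting every bias to 0.1 are assumptions chosen purely for illustration.

```python
import torch
import torch.nn as nn

# Example network, assumed only for this sketch.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

weight_dict = net.state_dict()
new_weight_dict = {}
for param_key, tensor in weight_dict.items():
    if param_key.endswith("bias"):
        # Hypothetical rule: set every bias to 0.1, leave weights untouched.
        new_weight_dict[param_key] = torch.full_like(tensor, 0.1)

# Merge the partial update into the full dict, then load the full dict,
# so that load_state_dict() sees every key it expects.
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
print(sorted(new_weight_dict))
```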
Umang Gupta answered Oct 06 '22 06:10