In TensorFlow, we can add L1 or L2 regularization in the Sequential model. I couldn't find an equivalent approach in PyTorch. How can we add regularization to the weights in PyTorch, inside the definition of the net:
import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        """ How to add L1 (or L2) regularization to the weights of this hidden layer?? """
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))   # activation function for hidden layer
        x = self.predict(x)          # linear output
        return x
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
# print(net) # net architecture
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss
Generally, L2 regularization is handled through the weight_decay argument of the optimizer in PyTorch (and you can set a different weight_decay per layer via optimizer parameter groups). This mechanism, however, doesn't allow for L1 regularization without extending the existing optimizers or writing a custom optimizer.
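For reference, a minimal sketch of that built-in L2 route (the weight_decay values here are placeholders, not recommendations):

# L2 on all parameters via weight_decay
optimizer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=1e-4)

# or per layer, using parameter groups with different weight_decay values
optimizer = torch.optim.SGD([
    {'params': net.hidden.parameters(), 'weight_decay': 1e-4},
    {'params': net.predict.parameters(), 'weight_decay': 0.0},
], lr=0.2)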
According to the TensorFlow docs, the L1 regularizer applies a reduce_sum(abs(x)) penalty and the L2 regularizer a reduce_sum(square(x)) penalty. Probably the easiest way to achieve the same thing in PyTorch is to add these penalty terms directly to the loss used for gradient computation during training.
# set l1_weight and l2_weight to non-zero values to enable penalties
# inside the training loop (given input x and target y)
...
pred = net(x)
loss = loss_func(pred, y)
# compute penalty only for net.hidden parameters
l1_penalty = l1_weight * sum(p.abs().sum() for p in net.hidden.parameters())
l2_penalty = l2_weight * sum(p.pow(2).sum() for p in net.hidden.parameters())
loss_with_penalty = loss + l1_penalty + l2_penalty
optimizer.zero_grad()
loss_with_penalty.backward()
optimizer.step()
# The pre-penalty loss is the one we ultimately care about
print('loss:', loss.item())
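If the penalty is needed in several places, it could be factored into a small helper; the function name and the weight values below are just illustrative:

def regularization_penalty(module, l1_weight=0.0, l2_weight=0.0):
    # L1: sum of absolute values, L2: sum of squares, over the module's trainable parameters
    params = [p for p in module.parameters() if p.requires_grad]
    l1 = sum(p.abs().sum() for p in params)
    l2 = sum(p.pow(2).sum() for p in params)
    return l1_weight * l1 + l2_weight * l2

# usage inside the training loop
loss_with_penalty = loss + regularization_penalty(net.hidden, l1_weight=1e-4, l2_weight=1e-4)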