How can I calculate the loss without the weight decay in Keras?



I defined a convolutional layer and also used L2 weight decay in Keras.

When I define the loss in model.compile(), is all of the weight decay loss included in this loss? If the weight decay loss is included in the total loss, how can I get the loss without the weight decay during training?

I want to monitor the loss without the weight decay, while still having the weight decay take part in training.

1 Answer

Yes, weight decay losses are included in the loss value printed on the screen.

The value you want to monitor is the total loss minus the sum of regularization losses.

  • The total loss is model.total_loss.
  • The regularization losses are collected in the list model.losses.
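In symbols, assuming each regularized kernel $W_k$ uses l2(λ_k), which adds λ_k times the sum of its squared entries (Keras' l2 has no factor of 1/2):

$$
L_{\text{total}} = L_{\text{data}} + \sum_k \lambda_k \sum_i w_{k,i}^2
\qquad\Longrightarrow\qquad
L_{\text{data}} = L_{\text{total}} - \sum_k \lambda_k \sum_i w_{k,i}^2
$$

Here $L_{\text{total}}$ corresponds to model.total_loss, and each term $\lambda_k \sum_i w_{k,i}^2$ to one entry of model.losses.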

The following lines can be found in the source code of model.compile():

# Add regularization penalties
# and other layer-specific losses.
for loss_tensor in self.losses:
    total_loss += loss_tensor

To get the loss without weight decay, reverse the operation above: the value to monitor is model.total_loss - sum(model.losses).
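A framework-free sketch of that add-then-subtract bookkeeping, with plain floats and lists standing in for Keras tensors. The kernels, the data loss and the helper l2_penalty are all hypothetical; l2_penalty only imitates the form of Keras' l2(coeff), i.e. coeff times the sum of squared weights.

```python
# Framework-free sketch: plain floats/lists stand in for Keras tensors.
# All numbers and names here are hypothetical.

def l2_penalty(weights, coeff=0.01):
    """Same form as Keras' l2(coeff): coeff * sum of squared weights."""
    return coeff * sum(w * w for w in weights)

# Two hypothetical regularized kernels, flattened to lists.
kernels = [[1.0, -2.0], [0.5, 0.5, 1.0]]
reg_losses = [l2_penalty(k) for k in kernels]  # plays the role of model.losses

data_loss = 0.5  # hypothetical unregularized loss (e.g. MSE)

# Mirror of the model.compile() snippet: add every penalty to the loss.
total_loss = data_loss
for loss_tensor in reg_losses:
    total_loss += loss_tensor

# Reverse it: total loss minus the sum of the regularization losses.
loss_no_weight_decay = total_loss - sum(reg_losses)

print(total_loss, loss_no_weight_decay)
```

Subtracting sum(reg_losses) recovers data_loss up to floating-point rounding.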

Monitoring this value is a bit trickier. Fortunately, the metric lists of a Keras model (model.metrics_names and model.metrics_tensors) are not fixed until model.fit() is called, so you can append this value to them, and it will be printed on the screen during fitting.
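To illustrate the distinction between what is optimized and what is monitored, here is a toy sketch in plain Python (not Keras code; the data, learning rate, and the helpers data_loss and penalty are hypothetical). The weight decay enters the gradient, so it still takes part in training, while the second logged value subtracts the penalty back out, which is what the extra Keras metric reports.

```python
# Framework-free toy sketch (not Keras code): fit y = w * x by gradient
# descent with an L2 penalty, logging the loss with and without the penalty.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]   # hypothetical data; the unregularized optimum is w = 2
lam = 0.01             # L2 coefficient, analogous to l2(0.01)
lr = 0.05              # learning rate
w = 0.0


def data_loss(weight):
    """Mean squared error: the loss without weight decay."""
    return sum((weight * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def penalty(weight):
    """L2 weight-decay term, lam * weight**2."""
    return lam * weight * weight


history = []
for _ in range(50):
    # The gradient includes the penalty, so weight decay still shapes training.
    grad_data = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    w -= lr * (grad_data + 2 * lam * w)
    total = data_loss(w) + penalty(w)
    # Log both values, like 'loss' and 'loss_no_weight_decay' in the Keras log.
    history.append((total, total - penalty(w)))

print(f"w = {w:.4f}")
print(f"loss: {history[-1][0]:.4f} - loss_no_weight_decay: {history[-1][1]:.4f}")
```

The final w settles slightly below the unregularized optimum of 2, showing the pull of the decay term, while the second logged value tracks only the data term.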

Here's a simple example:

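from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.regularizers import l2

input_tensor = Input(shape=(64, 64, 3))
hidden = Conv2D(32, 1, kernel_regularizer=l2(0.01))(input_tensor)
hidden = GlobalAveragePooling2D()(hidden)
out = Dense(1)(hidden)
model = Model(input_tensor, out)
model.compile(loss='mse', optimizer='adam')

# Total loss minus the L2 penalties collected in model.losses
loss_no_weight_decay = model.total_loss - sum(model.losses)
# Register it as an extra metric so fit() prints it (older Keras 2.x attributes)
model.metrics_tensors.append(loss_no_weight_decay)
model.metrics_names.append('loss_no_weight_decay')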

When you run model.fit(), something like this will be printed to the screen:

Epoch 1/1
100/100 [==================] - 0s - loss: 0.5764 - loss_no_weight_decay: 0.5178

You can also verify whether this value is correct by computing the L2 regularization manually:

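import numpy as np
# model.layers[0] is the InputLayer, so layers[1] is the Conv2D layer
conv_kernel = model.layers[1].get_weights()[0]  # [0] is the kernel, [1] the bias
print(np.sum(0.01 * np.square(conv_kernel)))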

In my case, the printed value is 0.0585, which is indeed the difference between loss and loss_no_weight_decay (with some rounding error).
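Checking the arithmetic with the two values from the training log:

$$
0.5764 - 0.5178 = 0.0586 \approx 0.0585
$$

Both logged values are rounded to four decimals, so their difference can be off by one unit in the last place.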

