Center Loss in Keras

Tags:

keras

I want to implement the center loss described in [http://ydwen.github.io/papers/WenECCV16.pdf] in Keras.

I started by creating a network with two outputs, like this:

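from keras.layers import Input, Dense, Activation
from keras.models import Model

inputs = Input(shape=(100, 100, 3))
...
fc = Dense(100)(previous_layer)  # previous_layer: output of the elided layers above
softmax = Activation('softmax')(fc)
model = Model(inputs=inputs, outputs=[softmax, fc])
model.compile(optimizer='sgd',
              loss=['categorical_crossentropy', center_loss],  # center_loss: not written yet
              metrics=['accuracy'], loss_weights=[1., 0.2])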

First, is this the right way to proceed?

Second, I don't know how to implement center_loss in Keras. Center loss resembles mean squared error, but instead of comparing values to fixed labels, it compares them to class centers that are updated at each iteration.
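To make the comparison with mean squared error concrete, here is a minimal NumPy sketch of the loss as written in the paper, L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2, assuming integer labels and a centers array that is kept up to date elsewhere (the function name center_loss is only illustrative). The only difference from a squared-error loss is that the target for sample i is the row centers[labels[i]], looked up per sample, instead of a fixed label:

```python
import numpy as np


def center_loss(features, labels, centers):
    """Center loss L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2 (illustrative sketch).

    features: (batch, num_feats) deep features x_i
    labels:   (batch,) integer class ids y_i
    centers:  (num_classes, num_feats) current class centers c_j
    """
    # The "targets" are not fixed: each sample is compared with the
    # current center of its own class, selected by its integer label.
    targets = centers[labels]
    return 0.5 * np.sum((features - targets) ** 2)


features = np.array([[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1, 0])
centers = np.array([[0.0, 1.0], [3.0, 1.0]])
print(center_loss(features, labels, centers))  # 1.5
```

Unlike Keras' mean_squared_error, this sums over the batch instead of averaging; dividing by the batch size gives a per-sample mean if you prefer that scale.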

Thank you for your help

asked Oct 21 '16 at 10:10 by slegall56



1 Answer

In my view, you can implement this as a custom layer by following these steps:

  1. Write a custom layer ComputeCenter that

    • takes two inputs: i) the ground-truth labels y_true (not one-hot encoded, just integers) and ii) the predicted features y_pred (the deep features whose class centers are being tracked)

    • holds a look-up table W, a num_classes x num_feats array, as weights that are updated manually rather than by the optimizer (compare the moving mean and variance in the BatchNormalization layer, which are non-trainable weights); W[j] holds the moving average of the features of the jth class.

    • computes the center loss as specified in the paper.

    • outputs the resulting distance array D
  2. To compute the center loss, you need to

    • i) update W[j] using y_pred[k] for every sample k with y_true[k]=j,
    • ii) retrieve the center c_true[k]=W[j] for each sample y_pred[k] with y_true[k]=j,
    • iii) compute the distance between y_pred and c_true.
    • Here c_true[k] = W[j], where k is the sample index and j is the ground-truth label of y_pred[k].
  3. Use model.add_loss() to add this loss to the model. Note: don't also pass it to model.compile(loss=...).
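Step 2 for one mini-batch can be sketched in NumPy as follows, assuming integer labels and float features. For step i) it uses the center update rule from the paper, delta_c_j = sum over {i: y_i = j} of (c_j - x_i) / (1 + n_j) followed by c_j <- c_j - alpha * delta_c_j, rather than a generic moving average; the function name center_step and the default alpha = 0.5 are my own choices. It keeps the order listed above (update, then retrieve); as far as I can tell the paper computes the loss with the centers from before the update, so you may want to move i) after ii) and iii).

```python
import numpy as np


def center_step(features, labels, centers, alpha=0.5):
    """Steps 2 i)-iii) for one mini-batch; returns (updated centers, D).

    features: (batch, num_feats) float array, the y_pred of the answer
    labels:   (batch,) integer class ids, the y_true of the answer
    centers:  (num_classes, num_feats) float array, the look-up table W
    alpha:    center learning rate (assumed value)
    """
    num_classes = centers.shape[0]
    # n_j: number of samples of class j in this batch
    counts = np.bincount(labels, minlength=num_classes).astype(centers.dtype)
    # Sum of the features of each class in this batch
    feat_sums = np.zeros_like(centers)
    np.add.at(feat_sums, labels, features)
    # delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j); zero for absent classes
    delta = (counts[:, None] * centers - feat_sums) / (1.0 + counts[:, None])
    # i) update W[j]
    new_centers = centers - alpha * delta
    # ii) retrieve c_true[k] = W[y_true[k]]
    c_true = new_centers[labels]
    # iii) D[k] = 1/2 * ||y_pred[k] - c_true[k]||^2
    D = 0.5 * np.sum((features - c_true) ** 2, axis=1)
    return new_centers, D


features = np.array([[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1, 0])
centers = np.array([[0.0, 1.0], [3.0, 1.0]])
new_centers, D = center_step(features, labels, centers)
print(new_centers)
print(D)
```

The 1 in the denominator keeps classes that are missing from the batch unchanged instead of dividing by zero.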

Finally, you can scale the center loss by a coefficient (λ in the paper) if needed.
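Putting steps 1 to 3 and the coefficient together, a tf.keras (TensorFlow 2) sketch might look like the following. Apart from the name ComputeCenter, everything is an assumption for illustration: the arguments alpha and loss_weight, the zero initialization of W, the paper's center update rule, and calling the layer's own add_loss() inside call() instead of model.add_loss(). W is a non-trainable weight that is only updated when training=True, and D is computed with the centers from before the update.

```python
import numpy as np
import tensorflow as tf


class ComputeCenter(tf.keras.layers.Layer):
    """Sketch of a center-loss layer: inputs [y_true, y_pred], output D.

    y_true: integer class ids, shape (batch,) or (batch, 1)
    y_pred: features, shape (batch, num_feats)
    """

    def __init__(self, num_classes, alpha=0.5, loss_weight=0.2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.alpha = alpha              # center learning rate (assumed value)
        self.loss_weight = loss_weight  # coefficient on the center loss (assumed value)

    def build(self, input_shape):
        num_feats = int(input_shape[1][-1])
        # Look-up table W: one center per class, updated by hand, not by the optimizer.
        self.centers = self.add_weight(
            name="centers", shape=(self.num_classes, num_feats),
            initializer="zeros", trainable=False)

    def call(self, inputs, training=None):
        y_true, y_pred = inputs
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # ii) c_true[k] = W[y_true[k]], read before this batch's update
        c_true = tf.gather(self.centers, labels)
        # iii) D[k] = 1/2 * ||y_pred[k] - c_true[k]||^2; W is not trainable,
        # so the gradient of this loss only reaches y_pred
        D = 0.5 * tf.reduce_sum(tf.square(y_pred - c_true), axis=1)
        if training:
            # i) update W with the paper's rule, treating the features as constants
            x = tf.stop_gradient(y_pred)
            ones = tf.ones_like(labels, dtype=x.dtype)
            counts = tf.math.unsorted_segment_sum(ones, labels, self.num_classes)[:, None]
            feat_sums = tf.math.unsorted_segment_sum(x, labels, self.num_classes)
            delta = (counts * self.centers - feat_sums) / (1.0 + counts)
            self.centers.assign_sub(self.alpha * delta)
        # 3) register the weighted center loss; do not also pass it to compile()
        self.add_loss(self.loss_weight * tf.reduce_mean(D))
        return D


layer = ComputeCenter(num_classes=2)
D = layer([tf.constant([0, 1, 0]),
           tf.constant([[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]])], training=True)
print(D.numpy(), layer.centers.numpy())
```

In a functional model you would feed the integer labels through a second Input, call ComputeCenter(num_classes)([labels, fc]), and compile only the softmax output with categorical_crossentropy. Keras drops layers that do not lead to a model output, so you may need to expose D as an extra output to keep the layer in the graph.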

answered Sep 24 '22 at 09:09 by pitfall
