
facenet triplet loss with keras


I am trying to implement FaceNet in Keras with the TensorFlow backend, and I have a problem with the triplet loss.

I call the fit function with 3*n images and define my custom loss function as follows:

    from keras import backend as K

    def triplet_loss(self, y_true, y_pred):
        # Group consecutive rows of the batch as (anchor, positive, negative).
        embeddings = K.reshape(y_pred, (-1, 3, output_dim))

        # Mean squared difference between anchor/positive and anchor/negative.
        positive_distance = K.mean(K.square(embeddings[:, 0] - embeddings[:, 1]), axis=-1)
        negative_distance = K.mean(K.square(embeddings[:, 0] - embeddings[:, 2]), axis=-1)
        return K.mean(K.maximum(0.0, positive_distance - negative_distance + _alpha))

    self._model.compile(loss=triplet_loss, optimizer="sgd")
    self._model.fit(x=x, y=y, nb_epoch=1, batch_size=len(x))

where y is just a dummy array filled with zeros.
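The reshape in the loss above only works if every three consecutive rows of the batch form one (anchor, positive, negative) triplet. The plain-Python sketch below mimics that grouping on hypothetical string labels (`a0`, `p0`, `n0` are made-up stand-ins for images, and `group_triplets` is an illustrative helper, not part of Keras). Keras `fit` has a `shuffle` argument that defaults to `True`, so it may be worth checking that the row order survives; whether that is what happens in this case is not confirmed.

```python
import random


def group_triplets(rows):
    """Group consecutive rows as (anchor, positive, negative),
    the same way reshape(-1, 3, dim) groups the batch rows."""
    if len(rows) % 3 != 0:
        raise ValueError("number of rows must be a multiple of 3")
    return [tuple(rows[i:i + 3]) for i in range(0, len(rows), 3)]


# Hypothetical labels: a = anchor, p = positive, n = negative.
batch = ["a0", "p0", "n0", "a1", "p1", "n1"]
triplets = group_triplets(batch)
print(triplets)  # [('a0', 'p0', 'n0'), ('a1', 'p1', 'n1')]

# If the rows are permuted before the loss sees them, the same
# grouping can pair up unrelated images.
shuffled = batch[:]
random.Random(0).shuffle(shuffled)
print(group_triplets(shuffled))
```

If the order matters, passing `shuffle=False` to `fit` is one way to rule this out.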

The problem is that even after the first iteration with batch size 20, the model starts predicting the same embedding for all the images. When I first run the prediction on the batch, every embedding is different. Then I call fit and predict again, and suddenly all the embeddings become almost the same for all the images in the batch.

Also note that there is a Lambda layer at the end of the model. It normalizes the output of the network so that all the embeddings have unit length, as suggested in the FaceNet paper.

Can anybody help me out here?

Model summary

    Layer (type)                     Output Shape          Param #     Connected to
    ====================================================================================================
    input_1 (InputLayer)             (None, 224, 224, 3)   0
    convolution2d_1 (Convolution2D)  (None, 112, 112, 64)  9472        input_1[0][0]
    batchnormalization_1 (BatchNormal(None, 112, 112, 64)  128         convolution2d_1[0][0]
    maxpooling2d_1 (MaxPooling2D)    (None, 56, 56, 64)    0           batchnormalization_1[0][0]
    convolution2d_2 (Convolution2D)  (None, 56, 56, 64)    4160        maxpooling2d_1[0][0]
    batchnormalization_2 (BatchNormal(None, 56, 56, 64)    128         convolution2d_2[0][0]
    convolution2d_3 (Convolution2D)  (None, 56, 56, 192)   110784      batchnormalization_2[0][0]
    batchnormalization_3 (BatchNormal(None, 56, 56, 192)   384         convolution2d_3[0][0]
    maxpooling2d_2 (MaxPooling2D)    (None, 28, 28, 192)   0           batchnormalization_3[0][0]
    convolution2d_5 (Convolution2D)  (None, 28, 28, 96)    18528       maxpooling2d_2[0][0]
    convolution2d_7 (Convolution2D)  (None, 28, 28, 16)    3088        maxpooling2d_2[0][0]
    maxpooling2d_3 (MaxPooling2D)    (None, 28, 28, 192)   0           maxpooling2d_2[0][0]
    convolution2d_4 (Convolution2D)  (None, 28, 28, 64)    12352       maxpooling2d_2[0][0]
    convolution2d_6 (Convolution2D)  (None, 28, 28, 128)   110720      convolution2d_5[0][0]
    convolution2d_8 (Convolution2D)  (None, 28, 28, 32)    12832       convolution2d_7[0][0]
    convolution2d_9 (Convolution2D)  (None, 28, 28, 32)    6176        maxpooling2d_3[0][0]
    merge_1 (Merge)                  (None, 28, 28, 256)   0           convolution2d_4[0][0]
                                                                       convolution2d_6[0][0]
                                                                       convolution2d_8[0][0]
                                                                       convolution2d_9[0][0]
    convolution2d_11 (Convolution2D) (None, 28, 28, 96)    24672       merge_1[0][0]
    convolution2d_13 (Convolution2D) (None, 28, 28, 32)    8224        merge_1[0][0]
    maxpooling2d_4 (MaxPooling2D)    (None, 28, 28, 256)   0           merge_1[0][0]
    convolution2d_10 (Convolution2D) (None, 28, 28, 64)    16448       merge_1[0][0]
    convolution2d_12 (Convolution2D) (None, 28, 28, 128)   110720      convolution2d_11[0][0]
    convolution2d_14 (Convolution2D) (None, 28, 28, 64)    51264       convolution2d_13[0][0]
    convolution2d_15 (Convolution2D) (None, 28, 28, 64)    16448       maxpooling2d_4[0][0]
    merge_2 (Merge)                  (None, 28, 28, 320)   0           convolution2d_10[0][0]
                                                                       convolution2d_12[0][0]
                                                                       convolution2d_14[0][0]
                                                                       convolution2d_15[0][0]
    convolution2d_16 (Convolution2D) (None, 28, 28, 128)   41088       merge_2[0][0]
    convolution2d_18 (Convolution2D) (None, 28, 28, 32)    10272       merge_2[0][0]
    convolution2d_17 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_16[0][0]
    convolution2d_19 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_18[0][0]
    maxpooling2d_5 (MaxPooling2D)    (None, 14, 14, 320)   0           merge_2[0][0]
    merge_3 (Merge)                  (None, 14, 14, 640)   0           convolution2d_17[0][0]
                                                                       convolution2d_19[0][0]
                                                                       maxpooling2d_5[0][0]
    convolution2d_21 (Convolution2D) (None, 14, 14, 96)    61536       merge_3[0][0]
    convolution2d_23 (Convolution2D) (None, 14, 14, 32)    20512       merge_3[0][0]
    maxpooling2d_6 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_3[0][0]
    convolution2d_20 (Convolution2D) (None, 14, 14, 256)   164096      merge_3[0][0]
    convolution2d_22 (Convolution2D) (None, 14, 14, 192)   166080      convolution2d_21[0][0]
    convolution2d_24 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_23[0][0]
    convolution2d_25 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_6[0][0]
    merge_4 (Merge)                  (None, 14, 14, 640)   0           convolution2d_20[0][0]
                                                                       convolution2d_22[0][0]
                                                                       convolution2d_24[0][0]
                                                                       convolution2d_25[0][0]
    convolution2d_27 (Convolution2D) (None, 14, 14, 112)   71792       merge_4[0][0]
    convolution2d_29 (Convolution2D) (None, 14, 14, 32)    20512       merge_4[0][0]
    maxpooling2d_7 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_4[0][0]
    convolution2d_26 (Convolution2D) (None, 14, 14, 224)   143584      merge_4[0][0]
    convolution2d_28 (Convolution2D) (None, 14, 14, 224)   226016      convolution2d_27[0][0]
    convolution2d_30 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_29[0][0]
    convolution2d_31 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_7[0][0]
    merge_5 (Merge)                  (None, 14, 14, 640)   0           convolution2d_26[0][0]
                                                                       convolution2d_28[0][0]
                                                                       convolution2d_30[0][0]
                                                                       convolution2d_31[0][0]
    convolution2d_33 (Convolution2D) (None, 14, 14, 128)   82048       merge_5[0][0]
    convolution2d_35 (Convolution2D) (None, 14, 14, 32)    20512       merge_5[0][0]
    maxpooling2d_8 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_5[0][0]
    convolution2d_32 (Convolution2D) (None, 14, 14, 192)   123072      merge_5[0][0]
    convolution2d_34 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_33[0][0]
    convolution2d_36 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_35[0][0]
    convolution2d_37 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_8[0][0]
    merge_6 (Merge)                  (None, 14, 14, 640)   0           convolution2d_32[0][0]
                                                                       convolution2d_34[0][0]
                                                                       convolution2d_36[0][0]
                                                                       convolution2d_37[0][0]
    convolution2d_39 (Convolution2D) (None, 14, 14, 144)   92304       merge_6[0][0]
    convolution2d_41 (Convolution2D) (None, 14, 14, 32)    20512       merge_6[0][0]
    maxpooling2d_9 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_6[0][0]
    convolution2d_38 (Convolution2D) (None, 14, 14, 160)   102560      merge_6[0][0]
    convolution2d_40 (Convolution2D) (None, 14, 14, 288)   373536      convolution2d_39[0][0]
    convolution2d_42 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_41[0][0]
    convolution2d_43 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_9[0][0]
    merge_7 (Merge)                  (None, 14, 14, 640)   0           convolution2d_38[0][0]
                                                                       convolution2d_40[0][0]
                                                                       convolution2d_42[0][0]
                                                                       convolution2d_43[0][0]
    convolution2d_44 (Convolution2D) (None, 14, 14, 160)   102560      merge_7[0][0]
    convolution2d_46 (Convolution2D) (None, 14, 14, 64)    41024       merge_7[0][0]
    convolution2d_45 (Convolution2D) (None, 7, 7, 256)     368896      convolution2d_44[0][0]
    convolution2d_47 (Convolution2D) (None, 7, 7, 128)     204928      convolution2d_46[0][0]
    maxpooling2d_10 (MaxPooling2D)   (None, 7, 7, 640)     0           merge_7[0][0]
    merge_8 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_45[0][0]
                                                                       convolution2d_47[0][0]
                                                                       maxpooling2d_10[0][0]
    convolution2d_49 (Convolution2D) (None, 7, 7, 192)     196800      merge_8[0][0]
    convolution2d_51 (Convolution2D) (None, 7, 7, 48)      49200       merge_8[0][0]
    maxpooling2d_11 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_8[0][0]
    convolution2d_48 (Convolution2D) (None, 7, 7, 384)     393600      merge_8[0][0]
    convolution2d_50 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_49[0][0]
    convolution2d_52 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_51[0][0]
    convolution2d_53 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_11[0][0]
    merge_9 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_48[0][0]
                                                                       convolution2d_50[0][0]
                                                                       convolution2d_52[0][0]
                                                                       convolution2d_53[0][0]
    convolution2d_55 (Convolution2D) (None, 7, 7, 192)     196800      merge_9[0][0]
    convolution2d_57 (Convolution2D) (None, 7, 7, 48)      49200       merge_9[0][0]
    maxpooling2d_12 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_9[0][0]
    convolution2d_54 (Convolution2D) (None, 7, 7, 384)     393600      merge_9[0][0]
    convolution2d_56 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_55[0][0]
    convolution2d_58 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_57[0][0]
    convolution2d_59 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_12[0][0]
    merge_10 (Merge)                 (None, 7, 7, 1024)    0           convolution2d_54[0][0]
                                                                       convolution2d_56[0][0]
                                                                       convolution2d_58[0][0]
                                                                       convolution2d_59[0][0]
    averagepooling2d_1 (AveragePoolin(None, 1, 1, 1024)    0           merge_10[0][0]
    flatten_1 (Flatten)              (None, 1024)          0           averagepooling2d_1[0][0]
    dense_1 (Dense)                  (None, 128)           131200      flatten_1[0][0]
    lambda_1 (Lambda)                (None, 128)           0           dense_1[0][0]
    ====================================================================================================
    Total params: 7456944
DalekSupreme asked Dec 10 '16 13:12




2 Answers

Besides the learning rate simply being too high, a likely cause is an unstable triplet selection strategy. If, for example, you only use 'hard triplets' (triplets where the a-n distance is smaller than the a-p distance), your network weights might collapse all embeddings to a single point. The loss then always equals the margin (your _alpha), because all embedding distances are zero.
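To make the collapse concrete, here is a small plain-Python sketch of the same loss formula as in the question (mean squared difference per pair, hinge with a margin). The values of `alpha` and the two-dimensional embeddings are made up for illustration; the helper names are not from Keras.

```python
def mean_squared_distance(u, v):
    """Mean of squared element-wise differences, as in the question's loss."""
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def triplet_loss(triplets, alpha):
    """Average hinge loss over (anchor, positive, negative) triplets."""
    losses = []
    for anchor, positive, negative in triplets:
        pos = mean_squared_distance(anchor, positive)
        neg = mean_squared_distance(anchor, negative)
        losses.append(max(0.0, pos - neg + alpha))
    return sum(losses) / len(losses)


alpha = 0.2  # hypothetical margin

# A hard triplet: the negative sits on the anchor, the positive is far away.
hard = [([1.0, 0.0], [0.0, 1.0], [1.0, 0.0])]
hard_loss = triplet_loss(hard, alpha)

# Collapsed embeddings: every image maps to the same point,
# so both distances are zero and only the margin remains.
point = [0.6, 0.8]
collapsed = [(point, point, point)] * 4
collapsed_loss = triplet_loss(collapsed, alpha)

# Well-separated embeddings: the loss reaches zero.
separated = [([1.0, 0.0], [1.0, 0.0], [0.0, 1.0])]
separated_loss = triplet_loss(separated, alpha)

print(hard_loss, collapsed_loss, separated_loss)
```

For the hard triplet the loss is above the margin, so moving every embedding to one point lowers it to about `alpha`. That is a cheap improvement for the optimizer, even though it is far from the zero loss of the separated case.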

This can be fixed by also using other kinds of triplets, such as 'semi-hard triplets', where the a-p distance is smaller than the a-n distance but the gap between the two is still smaller than the margin. So it might help to always check which kind of triplets you are training on. It is explained in more detail in this blog post: https://omoindrot.github.io/triplet-loss
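The categories described above can be checked with a few comparisons. The sketch below is only an illustration of those definitions: `classify_triplet` is a hypothetical helper, the distances are made-up numbers, and how to treat exact ties is an assumption of this sketch.

```python
def classify_triplet(d_ap, d_an, margin):
    """Label a triplet from its anchor-positive and anchor-negative distances.

    hard:      the negative is closer to the anchor than the positive
    semi-hard: the negative is farther, but by less than the margin
    easy:      the negative is farther by at least the margin (loss is zero)
    Ties (d_an == d_ap) are counted as semi-hard here, which is an assumption.
    """
    if d_an < d_ap:
        return "hard"
    if d_an < d_ap + margin:
        return "semi-hard"
    return "easy"


margin = 0.2  # hypothetical margin, plays the role of _alpha

# (d_ap, d_an) pairs with made-up values.
distances = [(1.0, 0.5), (0.5, 0.6), (0.5, 1.0)]
labels = [classify_triplet(d_ap, d_an, margin) for d_ap, d_an in distances]
print(labels)  # ['hard', 'semi-hard', 'easy']

# Keep only the semi-hard triplets for training.
semi_hard = [pair for pair, label in zip(distances, labels) if label == "semi-hard"]
print(semi_hard)  # [(0.5, 0.6)]
```

Counting how many triplets of each kind a batch contains is a quick way to see whether training is dominated by hard triplets.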

Jana Cavojska answered Oct 09 '22 14:10



Are you constraining your embeddings to "be on a d-dimensional hypersphere"? Try running tf.nn.l2_normalize on your embeddings right after they come out of the CNN.

The problem could be that the embeddings are sort of being smart alecks. One easy way to reduce the loss is to just set everything to zero. l2_normalize forces them to have unit length.

It looks like you'll want to add the normalization right after the last average pool.
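As a rough illustration of what L2 normalization does to a single embedding, here is a plain-Python sketch. It is not the TensorFlow implementation: the `epsilon` guard against division by zero is an assumption of this sketch, and the vectors are made-up values.

```python
import math


def l2_normalize(vector, epsilon=1e-12):
    """Scale a vector to unit Euclidean length.

    epsilon only avoids dividing by zero for an all-zero vector;
    its value here is an assumption of this sketch.
    """
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, epsilon) for x in vector]


unit = l2_normalize([3.0, 4.0])
print(unit)

# Vectors of different length but the same direction land on the same point.
same_direction = l2_normalize([6.0, 8.0])

# The all-zero vector stays at zero instead of raising an error.
zero = l2_normalize([0.0, 0.0])
```

Note that normalization fixes the length of each embedding, not its direction. It rules out the all-zero solution, but by itself it does not stop all embeddings from ending up at the same point on the sphere.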

chris answered Oct 09 '22 15:10
