
Triplet model for image retrieval from the Keras pretrained network

I want to implement a model for image retrieval. The model will be trained with a triplet loss function (as in FaceNet and similar architectures). My idea was to take a pretrained classification model from Keras (e.g. ResNet50) and turn it into a triplet architecture. This is my model in Keras:

from keras.applications.resnet50 import ResNet50
from keras.layers import Input, Flatten, Dense, Lambda
from keras.models import Model
from keras import backend as K

resnet_input = Input(shape=(224,224,3))
resnet_model = ResNet50(weights='imagenet', include_top = False, input_tensor=resnet_input)
net = resnet_model.output

net = Flatten(name='flatten')(net) 
net = Dense(512, activation='relu', name='embded')(net)
net = Lambda(l2Norm, output_shape=[512])(net)

base_model = Model(resnet_model.input, net, name='resnet_model')

input_shape=(224,224,3)
input_anchor = Input(shape=input_shape, name='input_anchor')
input_positive = Input(shape=input_shape, name='input_pos')
input_negative = Input(shape=input_shape, name='input_neg')

net_anchor = base_model(input_anchor)
net_positive = base_model(input_positive)
net_negative = base_model(input_negative)

positive_dist = Lambda(euclidean_distance, name='pos_dist')([net_anchor, net_positive])
negative_dist = Lambda(euclidean_distance, name='neg_dist')([net_anchor, net_negative])

stacked_dists = Lambda( 
            lambda vects: K.stack(vects, axis=1),
            name='stacked_dists'
)([positive_dist, negative_dist])

model = Model([input_anchor, input_positive, input_negative], stacked_dists, name='triple_siamese')

def triplet_loss(_, y_pred):
    margin = K.constant(1)
    return K.mean(K.maximum(K.constant(0), K.square(y_pred[0]) - K.square(y_pred[1]) + margin))

def accuracy(_, y_pred):
    return K.mean(y_pred[0] < y_pred[1])

def l2Norm(x):
    return  K.l2_normalize(x, axis=-1)

def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

The model should predict a feature vector for every image. The Euclidean distance between these vectors should be close to zero if the images are from the same class, and larger by at least the margin if they are not.
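
Only the shared base_model is needed for the retrieval step itself; here is a minimal sketch of the lookup (gallery_images and query_image are hypothetical, already preprocessed image arrays, not variables from the code above):

import numpy as np

# embed the gallery and the query with the shared network, then rank by Euclidean distance
gallery_embeddings = base_model.predict(gallery_images)        # shape (N, 512), L2-normalized
query_embedding = base_model.predict(query_image[np.newaxis])  # shape (1, 512)

distances = np.linalg.norm(gallery_embeddings - query_embedding, axis=1)
ranking = np.argsort(distances)   # indices of the most similar gallery images first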

I have already tried different learning rates, batch sizes, margins in the loss function, different output layers from the original ResNet model, different layers added on top of the ResNet, and training only the newly added layers vs. training the whole model. I also tried using the ResNet model without pretrained weights. The result was always the same: around 0.5 accuracy and 1.0 loss, no matter what I did. (The input images were preprocessed in the way expected for this model, with keras.applications.resnet50.preprocess_input.)

I didn't do any hard negative mining, which may result in slow convergence, but still, 0.5 accuracy in this case (see the accuracy function) means the predictions are random.
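
For illustration, a very simple offline hard-negative mining step (the embeddings and labels arrays here are hypothetical: descriptors from base_model.predict and the class ids of those images) would pick, for each anchor, the closest sample from a different class:

import numpy as np

def mine_hard_negatives(embeddings, labels):
    # pairwise Euclidean distances between all embeddings
    dists = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    hard_negatives = []
    for i in range(len(labels)):
        candidates = np.where(labels != labels[i])[0]   # samples from other classes only
        hard_negatives.append(candidates[np.argmin(dists[i, candidates])])
    return np.array(hard_negatives)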

So I started thinking that maybe I'm missing something really important here (it is quite a complex architecture). So please, if you notice anything wrong or suspicious in my implementation, I'd be really glad to hear about it.

Asked Dec 09 '17 by T.Poe


People also ask

What is a triplet neural network?

It's a kind of neural network architecture in which multiple parallel networks that share their weights are trained together. At prediction time, input data is passed through one network to compute a distributed embedding representation of the input.

How do you train a Siamese network with Triplet Loss?

You can train the network by taking an anchor image and comparing it with both a positive sample and a negative sample. The dissimilarity between the anchor image and the positive image must be low, and the dissimilarity between the anchor image and the negative image must be high.

Why is Triplet Loss better than contrastive loss?

Additionally, Triplet Loss is less greedy. Unlike Contrastive Loss, it is already satisfied when different samples are easily distinguishable from similar ones. It does not change the distances in a positive cluster if there is no interference from negative examples.
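
As a rough illustration of the difference (these are sketches, not the code used elsewhere on this page), contrastive loss acts on labelled pairs while triplet loss only compares the two distances of a triplet:

from keras import backend as K

# contrastive loss: pulls same-class pairs together, pushes different-class pairs past the margin
def contrastive_loss(y_true, distance):
    margin = K.constant(1)
    return K.mean(y_true * K.square(distance) +
                  (1 - y_true) * K.square(K.maximum(margin - distance, 0)))

# triplet loss: only asks that the anchor-positive distance be smaller than the anchor-negative distance by the margin
def triplet_loss_sketch(pos_dist, neg_dist):
    margin = K.constant(1)
    return K.mean(K.maximum(pos_dist - neg_dist + margin, 0))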


2 Answers

In case anyone is interested: rewriting

y_pred[0] and y_pred[1]

to

y_pred[:,0,0] and y_pred[:,1,0]

fixed it.

Now the model seems to be training (the loss is decreasing and the accuracy is increasing).
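
Why this helps: euclidean_distance returns tensors of shape (batch_size, 1), and K.stack(..., axis=1) combines them into a y_pred of shape (batch_size, 2, 1), so y_pred[0] selects the first sample in the batch rather than the positive-distance column. With the corrected indexing, the two functions read:

def triplet_loss(_, y_pred):
    margin = K.constant(1)
    # y_pred[:, 0, 0] is the anchor-positive distance, y_pred[:, 1, 0] the anchor-negative distance
    return K.mean(K.maximum(K.constant(0), K.square(y_pred[:, 0, 0]) - K.square(y_pred[:, 1, 0]) + margin))

def accuracy(_, y_pred):
    # fraction of triplets in which the positive is closer to the anchor than the negative
    return K.mean(y_pred[:, 0, 0] < y_pred[:, 1, 0])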

Answered Oct 19 '22 by T.Poe


I don't have enough reputation points to comment, so I am writing it like this.

I would like to do something similar to what you are doing, using your code.

I am new to CNNs and I am not sure what my training data should look like. Would you be willing to share the rest of your code? I would be very grateful!

Edit:

To answer my own question, which may be useful for someone else: this is how I did it on the Holidays photo dataset (http://lear.inrialpes.fr/%7Ejegou/data.php), and it works:

import os
import itertools
import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import imresize   # removed in SciPy >= 1.3; an equivalent resize from PIL/skimage works too
import keras
from keras.applications.resnet50 import ResNet50
from keras.layers import Input, Flatten, Dense, Lambda
from keras.models import Model
from keras import backend as K

# IMAGE_DIR should point to the directory containing the Holidays .jpg files

def get_random_image(img_groups, group_names, gid):
    gname = group_names[gid]
    photos = img_groups[gname]
    pid = np.random.choice(np.arange(len(photos)), size=1)[0]
    pname = photos[pid]
    return gname + pname + ".jpg"

def create_triples(image_dir):
    img_groups = {}
    for img_file in os.listdir(image_dir):
        prefix, suffix = img_file.split(".")
        gid, pid = prefix[0:4], prefix[4:]

        if gid in img_groups.keys():
            img_groups[gid].append(pid)
        else:
            img_groups[gid] = [pid]
    pos_triples, neg_triples = [], []

    for key in img_groups.keys():
        triples = [(key + x[0] + ".jpg", key + x[1] + ".jpg", str(int(key)+3 if int(key)<1495 else int(key)-3)+'01'+'.jpg')
                 for x in itertools.combinations(img_groups[key], 2)]
        pos_triples.extend(triples)

    return pos_triples

def triplet_loss(y_true, y_pred):
    margin = K.constant(0.2)
    return K.mean(K.maximum(K.constant(0), K.square(y_pred[:,0,0]) - K.square(y_pred[:,1,0]) + margin))

def accuracy(y_true, y_pred):
    return K.mean(y_pred[:,0,0] < y_pred[:,1,0])

def l2Norm(x):
    return  K.l2_normalize(x, axis=-1)

def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

triples_data = create_triples(IMAGE_DIR)


dim = 1500
h = 299
w = 299
anchor = np.zeros((dim,h,w,3))
positive = np.zeros((dim,h,w,3))
negative = np.zeros((dim,h,w,3))


for n,val in enumerate(triples_data[0:1500]):
    image_anchor = plt.imread(os.path.join(IMAGE_DIR, val[0]))
    image_anchor = imresize(image_anchor, (h, w))    
    image_anchor = image_anchor.astype("float32")
    #image_anchor = image_anchor/255.
    image_anchor = keras.applications.resnet50.preprocess_input(image_anchor, data_format='channels_last')
    anchor[n] = image_anchor

    image_positive = plt.imread(os.path.join(IMAGE_DIR, val[1]))
    image_positive = imresize(image_positive, (h, w))
    image_positive = image_positive.astype("float32")
    #image_positive = image_positive/255.
    image_positive = keras.applications.resnet50.preprocess_input(image_positive, data_format='channels_last')
    positive[n] = image_positive

    image_negative = plt.imread(os.path.join(IMAGE_DIR, val[2]))
    image_negative = imresize(image_negative, (h, w))
    image_negative = image_negative.astype("float32")
    #image_negative = image_negative/255.
    image_negative = keras.applications.resnet50.preprocess_input(image_negative, data_format='channels_last')
    negative[n] = image_negative

# dummy targets for model.fit: the custom triplet_loss ignores y_true, so only the shape (dim, 2, 1) has to match stacked_dists
Y_train = np.random.randint(2, size=(1,2,dim)).T


resnet_input = Input(shape=(h,w,3))
resnet_model = ResNet50(weights='imagenet', include_top = False, input_tensor=resnet_input)


# freeze the pretrained ResNet50 weights so that only the newly added embedding layers are trained
for layer in resnet_model.layers:
    layer.trainable = False


net = resnet_model.output
net = Flatten(name='flatten')(net) 
net = Dense(128, activation='relu', name='embed')(net)
net = Dense(128, activation='relu', name='embed2')(net)
net = Dense(128, activation='relu', name='embed3')(net)
net = Lambda(l2Norm, output_shape=[128])(net)

base_model = Model(resnet_model.input, net, name='resnet_model')

input_shape=(h,w,3)
input_anchor = Input(shape=input_shape, name='input_anchor')
input_positive = Input(shape=input_shape, name='input_pos')
input_negative = Input(shape=input_shape, name='input_neg')

net_anchor = base_model(input_anchor)
net_positive = base_model(input_positive)
net_negative = base_model(input_negative)

positive_dist = Lambda(euclidean_distance, name='pos_dist')([net_anchor, net_positive])
negative_dist = Lambda(euclidean_distance, name='neg_dist')([net_anchor, net_negative])

stacked_dists = Lambda( 
            lambda vects: K.stack(vects, axis=1),
            name='stacked_dists'
)([positive_dist, negative_dist])


model = Model([input_anchor, input_positive, input_negative], stacked_dists, name='triple_siamese')

model.compile(optimizer="rmsprop", loss=triplet_loss, metrics=[accuracy])

model.fit([anchor, positive, negative], Y_train, epochs=50,  batch_size=15, validation_split=0.2)

model.save('triplet_loss_resnet50.h5')
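
One possible way to reuse the saved model for retrieval later (a sketch, not part of the original answer; it assumes the same environment and that triplet_loss, accuracy, l2Norm and euclidean_distance are still defined) is to reload it with the custom objects and pull out the shared embedding branch by its name:

from keras.models import load_model

model = load_model('triplet_loss_resnet50.h5',
                   custom_objects={'triplet_loss': triplet_loss,
                                   'accuracy': accuracy,
                                   'l2Norm': l2Norm,
                                   'euclidean_distance': euclidean_distance})

# the shared base network was created with name='resnet_model', so it is accessible as a layer
embedding_model = model.get_layer('resnet_model')
descriptors = embedding_model.predict(anchor[:10])   # 128-dimensional descriptors for the first 10 anchor images
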
Answered Oct 19 '22 by bmorvaj