How to Implement Center Loss and Other Running Averages of Labeled Embeddings

Tags:

tensorflow

A recent paper (here) introduced a secondary loss function that they called center loss. It is based on the distance between the embeddings in a batch and the running average embedding for each of the respective classes. There has been some discussion in the TF Google groups (here) regarding how such embedding centers can be computed and updated. I've put together some code to generate class-average embeddings in my answer below.

Is this the best way to do this?
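For context, a "class-average embedding" here is just the running mean of all embeddings seen so far for each label. A minimal NumPy sketch of that baseline (function and variable names are illustrative, not from the paper):

```python
import numpy as np

def update_class_averages(centers, counts, embed_batch, label_batch):
    """Cumulative per-class average: centers[c] holds the mean of every
    embedding observed so far with label c."""
    for emb, lbl in zip(embed_batch, label_batch):
        counts[lbl] += 1
        # Incremental mean: new_mean = old_mean + (x - old_mean) / n
        centers[lbl] += (emb - centers[lbl]) / counts[lbl]
    return centers, counts

nclass, ndims = 3, 2  # illustrative sizes
centers = np.zeros((nclass, ndims))
counts = np.zeros(nclass, dtype=int)
embed_batch = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
label_batch = np.array([0, 0, 1])
centers, counts = update_class_averages(centers, counts, embed_batch, label_batch)
# centers[0] is now [2.0, 0.0] (mean of the two class-0 embeddings);
# centers[1] is [0.0, 2.0]
```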

asked Oct 20 '16 by RobR


1 Answer

The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model becomes more refined. This is because the previous center-finding routine averages all instances seen since the start of training, and therefore tracks changes in the expected value very slowly. A moving-window average is preferred instead. An exponential moving-window variant is as follows:

import tensorflow as tf

def get_embed_centers(embed_batch, label_batch):
    '''Exponential moving-window average of per-class embedding centers.
    Increase decay (in [0.0, 1.0]) for a longer effective window.
    '''
    decay = 0.95
    # Reuse the centers variable created in the 'embed' scope below.
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")

    label_batch = tf.reshape(label_batch, [-1])
    # Gather the current center for each example's label.
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    # Move each center a fraction (1 - decay) of the way toward the
    # batch embeddings that carry its label.
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch


nclass, ndims = 10, 2  # example sizes; set to your class count / embedding dims

with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()
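The scatter_sub step is equivalent to `center = decay * center + (1 - decay) * embedding`. A small NumPy sketch (array sizes illustrative) reproduces the same gather/scatter semantics and is handy for sanity-checking the graph; note that duplicate labels within a batch all compute their difference from the pre-batch center, and those differences accumulate:

```python
import numpy as np

def ema_update(embed_ctrs, embed_batch, label_batch, decay=0.95):
    """NumPy mirror of the TF op above: gather old centers, then
    scatter-subtract (1 - decay) * (old_center - embedding).
    np.subtract.at accumulates duplicate labels, like tf.scatter_sub."""
    old_ctrs_batch = embed_ctrs[label_batch]          # tf.gather
    dif = (1 - decay) * (old_ctrs_batch - embed_batch)
    np.subtract.at(embed_ctrs, label_batch, dif)      # tf.scatter_sub
    return embed_ctrs

ctrs = np.zeros((2, 3))                # nclass=2, ndims=3 (illustrative)
batch = np.ones((3, 3))
labels = np.array([0, 0, 1])
ctrs = ema_update(ctrs, batch, labels)
# class 0 received two updates, each computed from the pre-batch center:
# 0 - 2 * 0.05 * (0 - 1) = 0.1; class 1 received one: 0.05
```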
answered Sep 28 '22 by RobR