A recent paper (here) introduced a secondary loss function, which the authors call center loss. It is based on the distance between the embeddings in a batch and the running average embedding for each of their respective classes. There has been some discussion in the TF Google groups (here) about how such embedding centers can be computed and updated. I've put together some code to generate class-average embeddings in my answer below.
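For context, the center-loss term itself is just the (halved) mean squared distance between each embedding and the center of its class. A minimal sketch, assuming embed_batch and the per-example centers embed_ctrs_batch are both [batch, ndims] tensors (names are illustrative, not from the paper's code):

    # Sketch only: center loss as the mean squared distance between each
    # embedding and the running center of its class.
    center_loss = 0.5 * tf.reduce_mean(
        tf.reduce_sum(tf.square(embed_batch - embed_ctrs_batch), axis=1))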
Is this the best way to do this?
The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model becomes more refined. This is because the previous center-finding routine averages all instances seen since the start and therefore tracks changes in the expected value very slowly. Instead, a moving-window average is preferred. An exponential moving-window variant is as follows:
import tensorflow as tf

def get_embed_centers(embed_batch, label_batch):
    ''' Exponential moving window average. Increase decay for longer windows [0.0 1.0]
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")
    label_batch = tf.reshape(label_batch, [-1])
    # Current centers for the classes present in this batch.
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    # Move each batch class center a fraction (1 - decay) toward the new embeddings.
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    # Return the updated center for each example's class.
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch
nclass, ndims = 10, 2  # example sizes; use your own class count and embedding size

with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                                     initializer=tf.constant_initializer(0),
                                     trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.initialize_all_variables())
    tf.get_default_graph().finalize()
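The centers are then updated inside the training loop simply by running embed_ctrs_batch on each batch; the op both updates the stored centers and returns the updated center for each example's class. A minimal usage sketch, continuing inside the with tf.Session() block above (assumes import numpy as np; the label and embedding values are made up for illustration):

    batch_labels = np.array([0, 1, 1, 2], dtype=np.int32)
    batch_embeds = np.random.rand(4, ndims).astype(np.float32)
    ctrs_batch = sess.run(embed_ctrs_batch,
                          feed_dict={label_batch_ph: batch_labels,
                                     embed_batch_ph: batch_embeds})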