Variational auto-encoder: implementing warm-up in Keras

I recently read this paper, which introduces a procedure called "warm-up" (WU): the KL-divergence term of the loss is multiplied by a variable whose value depends on the epoch number (it increases linearly from 0 to 1).

I was wondering whether this is the right way to do it:

beta = K.variable(value=0.0)

def vae_loss(x, x_decoded_mean):
    # cross entropy
    xent_loss = K.mean(objectives.categorical_crossentropy(x, x_decoded_mean))

    # kl divergence
    for k in range(n_sample):
        epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,
                              std=1.0)  # used for every z_i sampling
        # Sample several layers of latent variables
        for mean, var in zip(means, variances):
            z_ = mean + K.exp(K.log(var) / 2) * epsilon

            # build z
            try:
                z = tf.concat([z, z_], -1)
            except NameError:
                z = z_
            except TypeError:
                z = z_

            # sum loss (using a MC approximation)
            try:
                loss += K.sum(log_normal2(z_, mean, K.log(var)), -1)
            except NameError:
                loss = K.sum(log_normal2(z_, mean, K.log(var)), -1)
        print("z", z)
        loss -= K.sum(log_stdnormal(z) , -1)
        z = None
    kl_loss = loss / n_sample
    print('kl loss:', kl_loss)

    # result
    result = beta*kl_loss + xent_loss
    return result

# define callback to change the value of beta at each epoch
def warmup(epoch):
    value = (epoch/10.0) * (epoch <= 10.0) + 1.0 * (epoch > 10.0)
    print("beta:", value)
    beta = K.variable(value=value)

from keras.callbacks import LambdaCallback
wu_cb = LambdaCallback(on_epoch_end=lambda epoch, log: warmup(epoch))


# train model
vae.fit(
    padded_X_train[:last_train,:,:],
    padded_X_train[:last_train,:,:],
    batch_size=batch_size,
    nb_epoch=nb_epoch,
    verbose=0,
    callbacks=[tb, wu_cb],
    validation_data=(padded_X_test[:last_test,:,:], padded_X_test[:last_test,:,:])
)

sbaur asked Mar 14 '17 13:03


1 Answer

This will not work. I tested it to figure out exactly why it was not working. The key thing to remember is that Keras creates a static graph once at the beginning of training.

Therefore, the vae_loss function is called only once to create the loss tensor, which means that the loss keeps referring to the same beta variable every time it is calculated. However, your warmup function assigns the name beta to a new K.variable (and, without a global declaration, that assignment only binds a name local to warmup). Thus, the beta used for calculating the loss is a different object from the one that gets updated, and its value will always be 0.
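The difference between rebinding a name and updating an object in place can be seen without Keras at all. The following is only a rough analogy in plain Python: the Variable class, build_loss, and both warmup functions are hypothetical stand-ins, not Keras APIs, and the closure plays the role of the loss tensor that is built once.

```python
class Variable:
    """Hypothetical stand-in for a backend variable (not the Keras class)."""

    def __init__(self, value):
        self.value = value


def build_loss(beta_var):
    # Called once, like vae_loss: the closure keeps a reference to beta_var.
    def loss(kl, xent):
        return beta_var.value * kl + xent

    return loss


beta = Variable(0.0)
loss = build_loss(beta)


def warmup_rebind(value):
    # Binds a new local name; the object captured by `loss` is untouched.
    beta = Variable(value)


def warmup_in_place(value):
    # Mutates the very object that `loss` already references.
    beta.value = value


warmup_rebind(0.5)
after_rebind = loss(2.0, 1.0)    # beta is still 0.0, so this is 1.0

warmup_in_place(0.5)
after_in_place = loss(2.0, 1.0)  # beta is now 0.5, so this is 2.0
```

Only the in-place update changes what the already-built loss computes.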

It is an easy fix. Just change this line in your warmup callback:

beta = K.variable(value=value)

to:

K.set_value(beta, value)

This way the actual value in beta gets updated "in place" rather than creating a new variable, and the loss will be properly re-calculated.
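Putting the fix together, a corrected callback might look like the sketch below. It assumes a Keras version whose backend exposes K.set_value, as in the question; the names linear_beta, make_warmup_callback, and warmup_epochs are introduced here for illustration and are not part of the original code.

```python
def linear_beta(epoch, warmup_epochs=10):
    """Linear warm-up weight: epoch / warmup_epochs, capped at 1.0."""
    return min(epoch / float(warmup_epochs), 1.0)


def make_warmup_callback(beta, warmup_epochs=10):
    """Build a callback that updates the existing `beta` variable in place."""
    # Imported lazily so the schedule above can be used without Keras installed.
    from keras import backend as K
    from keras.callbacks import LambdaCallback

    return LambdaCallback(
        on_epoch_end=lambda epoch, logs: K.set_value(
            beta, linear_beta(epoch, warmup_epochs)
        )
    )
```

It would be passed to fit in place of wu_cb, for example callbacks=[tb, make_warmup_callback(beta)], where beta is the same K.variable that vae_loss reads. Keeping the schedule in its own function makes it easy to check separately from training.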

Nigel answered Nov 16 '22 06:11