Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Tracking Multiple losses with Keras

Tags:

keras

For networks with competing losses, such as VAEs, it's useful to track each loss independently: the total loss, as well as the data (reconstruction) term and the KL term on the latent code.

Is this possible in Keras? The losses can be recovered with vae.losses, but they are raw TensorFlow tensors rather than Keras layers, so they can't be used directly in Keras (e.g., you can't create a second model that outputs the VAE losses).

One way to do this would seem to be adding them to the metrics list passed to compile, but they don't fit the metric interface, which expects functions of (y_true, y_pred).

Here's some sample code (sorry for the length); it's mildly adapted from the Keras VAE example. The major difference is that I've explicitly moved the computation of the KL divergence into a sampling layer, which feels more natural than the original example code.

'''This script demonstrates how to build a variational autoencoder with Keras.

Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''    
from keras.layers import Input, Dense, Lambda, Layer
from keras.models import Model
from keras import backend as K
from keras import metrics

# Hyper-parameters for the VAE example.
batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0  # stddev of the Gaussian noise drawn in CustomSamplingLayer.call


# Encoder: x -> Dense(intermediate_dim, relu) -> (z_mean, z_log_var).
# batch_shape fixes the batch dimension of the model input to batch_size.
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

class CustomSamplingLayer(Layer):
    """Reparameterised sampling layer that also registers the KL-divergence loss.

    Takes ``[z_mean, z_log_var]`` and returns
    ``z_mean + exp(z_log_var / 2) * epsilon``, where ``epsilon`` is drawn from
    a normal distribution with mean 0 and stddev ``epsilon_std``. As a side
    effect, ``call`` adds the batch-mean KL divergence between
    N(z_mean, exp(z_log_var)) and N(0, I) to the model's losses via
    ``add_loss``.
    """

    def __init__(self, **kwargs):
        super(CustomSamplingLayer, self).__init__(**kwargs)

    def kl_div_loss(self, z_mean, z_log_var):
        """Return the batch-mean KL divergence of N(z_mean, exp(z_log_var)) from N(0, I)."""
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(kl_loss)

    def call(self, inputs):
        """Register the KL loss and return a reparameterised sample of z."""
        z_mean = inputs[0]
        z_log_var = inputs[1]
        loss = self.kl_div_loss(z_mean, z_log_var)
        self.add_loss(loss, inputs=inputs)
        # Draw noise with the runtime shape of z_mean instead of the global
        # (batch_size, latent_dim), so the layer is not tied to one batch size.
        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.,
                                  stddev=epsilon_std)
        return z_mean + K.exp(z_log_var / 2) * epsilon

    def compute_output_shape(self, input_shape):
        """The sample has the same shape as z_mean, the first input."""
        return input_shape[0]

# Sample z from (z_mean, z_log_var); CustomSamplingLayer also registers the
# KL term through add_loss inside its call().
z = CustomSamplingLayer()([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

# Custom loss layer
class CustomVariationalLayer(Layer):
    """Pass-through layer that registers the VAE reconstruction loss.

    Takes ``[x, x_decoded_mean]``, adds ``original_dim`` times the batch-mean
    binary cross-entropy between them to the model's losses via ``add_loss``,
    and returns ``x_decoded_mean`` unchanged.
    """

    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean):
        """Return the batch-mean binary cross-entropy scaled by original_dim."""
        xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
        return K.mean(xent_loss)

    def call(self, inputs):
        """Register the reconstruction loss and return the decoder output."""
        x = inputs[0]
        x_decoded_mean = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean)
        # Register the reconstruction term at full weight so it actually
        # contributes to the gradient alongside the KL term.
        self.add_loss(loss, inputs=inputs)
        return x_decoded_mean

    def compute_output_shape(self, input_shape):
        """The output is x_decoded_mean, the second input."""
        return input_shape[1]
# Route the decoder output through the loss layer so its add_loss call
# becomes part of the model.
y = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, y)
# loss=None: the training objective is whatever the custom layers registered
# through add_loss, not a (y_true, y_pred) loss function.
vae.compile(optimizer='rmsprop', loss=None)
like image 990
jrock Avatar asked Oct 30 '22 07:10

jrock


1 Answers

I attempted something like this on the Gumbel-softmax (categorical) VAE implemented in Keras here. The trick for me was using metrics, as you suggested. Here's the setup for the model:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from keras.layers import Input, Dense, Lambda
from keras.models import Model, Sequential
from keras import backend as K
from keras.datasets import mnist
from keras.activations import softmax
from keras.objectives import binary_crossentropy as bce


batch_size = 200
data_dim = 784
M = 10  # categories per latent variable (the softmax axis in sampling)
N = 10  # number of categorical latent variables
nb_epoch = 3
epsilon_std = 0.01  # NOTE(review): not referenced elsewhere in this snippet

tmp = []  # NOTE(review): not referenced elsewhere in this snippet

# Temperature annealing parameters for the Gumbel-softmax relaxation.
anneal_rate = 0.0003
min_temperature = 0.5

# Softmax temperature kept in a backend variable so the training loop can
# update it with K.set_value.
tau = K.variable(5.0, name="temperature")
# Encoder: x -> Dense(512, relu) -> Dense(256, relu) -> M*N logits.
x = Input(batch_shape=(batch_size, data_dim))
h = Dense(256, activation='relu')(Dense(512, activation='relu')(x))
logits_y = Dense(M*N)(h)

def sampling(logits_y):
    """Draw a Gumbel-softmax (relaxed one-hot) sample from ``logits_y``.

    Gumbel noise ``-log(-log(U))`` with ``U`` uniform on [0, 1) is added to
    the logits. The result is viewed as (-1, N, M), divided by the temperature
    ``tau``, softmax-normalised over the last axis, and flattened back to
    (-1, N*M). The 1e-20 offsets keep the logarithms away from log(0).
    """
    noise_u = K.random_uniform(K.shape(logits_y), 0, 1)
    neg_log_u = -K.log(noise_u + 1e-20)
    perturbed = logits_y - K.log(neg_log_u + 1e-20)
    per_group = K.reshape(perturbed, (-1, N, M))
    relaxed = softmax(per_group / tau)
    return K.reshape(relaxed, (-1, N*M))

# Relaxed categorical sample of the latent code, flattened to M*N values.
z = Lambda(sampling, output_shape=(M*N,))(logits_y)
# Decoder: M*N code -> Dense(256) -> Dense(512) -> data_dim sigmoid outputs.
generator = Sequential()
generator.add(Dense(256, activation='relu', input_shape=(N*M, )))
generator.add(Dense(512, activation='relu'))
generator.add(Dense(data_dim, activation='sigmoid'))
x_hat = generator(z)

Here I define the total loss used for optimization, followed by separate functions for its components. Note that KL_loss takes two arguments (y_true, y_pred) that it never uses: Keras throws an exception if a metric function doesn't accept these two arguments.

def _kl_divergence():
    """Return the per-sample KL(q(y|x) || Uniform(M)), summed over the N groups.

    Reads the module-level ``logits_y`` tensor, views it as (-1, N, M) and
    applies a softmax over the M categories to obtain q(y|x).
    """
    q_y = softmax(K.reshape(logits_y, (-1, N, M)))
    log_q_y = K.log(q_y + 1e-20)  # offset avoids log(0)
    kl_tmp = q_y * (log_q_y - K.log(1.0/M))
    return K.sum(kl_tmp, axis=(1, 2))


def gumbel_loss(x, x_hat):
    """Return the per-sample negative ELBO used as the training loss.

    -ELBO = reconstruction NLL + KL(q(y|x) || p(y)). The NLL is the binary
    cross-entropy scaled by ``data_dim``. The KL term is added, because
    subtracting it would reward the optimizer for increasing the divergence.
    """
    reconstruction = data_dim * bce(x, x_hat)
    return reconstruction + _kl_divergence()


def KL_loss(y_true, y_pred):
    """Metric: batch-mean KL term of ``gumbel_loss``.

    ``y_true`` and ``y_pred`` are unused, but Keras requires metric functions
    to accept them.
    """
    return K.mean(_kl_divergence())


def bce_loss(y_true, y_pred):
    """Metric: batch-mean reconstruction term of ``gumbel_loss``."""
    return K.mean(data_dim * bce(y_true, y_pred))

Then I compiled and trained the model:

vae = Model(x, x_hat)
# The total loss drives optimization; the metrics report its two components.
vae.compile(optimizer='adam', loss=gumbel_loss,
            metrics=[KL_loss, bce_loss])

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities by 1/255 and flatten each image to one row.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

# Exponential temperature schedule measured from the starting value:
#   tau = max(tau_0 * exp(-anneal_rate * completed_epochs), min_temperature)
# Multiplying the current tau by exp(-anneal_rate * e) each epoch instead
# would compound the decay and skip annealing after the first epoch.
initial_tau = K.get_value(tau)
for e in range(nb_epoch):
    vae.fit(x_train, x_train,
            shuffle=True,
            epochs=1,
            batch_size=batch_size,
            validation_data=(x_test, x_test))
    out = vae.predict(x_test, batch_size=batch_size)
    completed_epochs = e + 1
    K.set_value(tau, max(initial_tau * np.exp(-anneal_rate * completed_epochs),
                         min_temperature))

I experimented with callbacks and lots of other things before figuring this out, so hopefully it helps.

like image 118
Nicholas Normandin Avatar answered Jan 02 '23 21:01

Nicholas Normandin