Generating MNIST numbers using LSTM-CGAN in TensorFlow


Inspired by this article, I'm trying to build a conditional GAN that uses an LSTM to generate MNIST digits. I hope I'm using the same architecture as in the image below (except for the bidirectional RNN in the discriminator, taken from this paper):

[Image: LSTM-CGAN architecture diagram]
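For reference, the standard conditional GAN objective conditions both networks on the label y. The D_loss and G_loss lines in the code below are minibatch estimates of L_D and L_G, where L_G is the common non-saturating replacement for minimizing log(1 - D(G(z|y)|y)):

```latex
\min_G \max_D \;
  \mathbb{E}_{x,y}\!\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z,y}\!\left[\log\left(1 - D(G(z \mid y) \mid y)\right)\right]

\mathcal{L}_D = -\,\mathbb{E}\left[\log D(x \mid y) + \log\left(1 - D(G(z \mid y) \mid y)\right)\right],
\qquad
\mathcal{L}_G = -\,\mathbb{E}\left[\log D(G(z \mid y) \mid y)\right]
```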

When I run this model, I get very strange results. The image below shows what my model generates for the digit 3 after each epoch; it should look more like this. The results are really bad.

[Image: the digit 3 as generated by the model after each epoch]

The loss of my discriminator network drops very quickly to almost zero, while the loss of my generator network oscillates around a fixed value (maybe diverging slowly). I really don't know what's happening. Here is the most important part of my code (full code here):

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib)

timesteps = 28
X_dim = 28
Z_dim = 100
y_dim = 10

X = tf.placeholder(tf.float32, [None, timesteps, X_dim]) # reshaped MNIST image to 28x28
y = tf.placeholder(tf.float32, [None, y_dim]) # one-hot label
Z = tf.placeholder(tf.float32, [None, timesteps, Z_dim]) # numpy.random.uniform noise in range [-1; 1]

y_timesteps = tf.tile(tf.expand_dims(y, axis=1), [1, timesteps, 1]) # [None, timesteps, y_dim] - replicate y along axis=1

def discriminator(x, y):
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE) as vs:
        inputs = tf.concat([x, y], axis=2)
        D_cell = tf.contrib.rnn.LSTMCell(64)
        output, _ = tf.nn.dynamic_rnn(D_cell, inputs, dtype=tf.float32)
        last_output = output[:, -1, :]
        logit = tf.contrib.layers.fully_connected(last_output, 1, activation_fn=None)
        pred = tf.nn.sigmoid(logit)
        variables = [v for v in tf.global_variables() if v.name.startswith(vs.name)]  # only this scope's variables
        return variables, pred, logit

def generator(z, y):
    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE) as vs:
        inputs = tf.concat([z, y], axis=2)
        G_cell = tf.contrib.rnn.LSTMCell(64)
        output, _ = tf.nn.dynamic_rnn(G_cell, inputs, dtype=tf.float32)
        logit = tf.contrib.layers.fully_connected(output, X_dim, activation_fn=None)
        pred = tf.nn.sigmoid(logit)
        variables = [v for v in tf.global_variables() if v.name.startswith(vs.name)]  # only this scope's variables
        return variables, pred

G_vars, G_sample = generator(Z, y_timesteps)
D_vars, D_real, D_logit_real = discriminator(X, y_timesteps)
_, D_fake, D_logit_fake = discriminator(G_sample, y_timesteps)  # reuses D's weights via AUTO_REUSE

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_vars)
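The label-conditioning step above (tf.expand_dims + tf.tile, then tf.concat on axis=2) can be mimicked with plain Python lists to check the shapes: every one of the 28 timesteps receives the same one-hot label appended to its features, so the discriminator's LSTM sees 28 + 10 = 38 features per step and the generator's sees 100 + 10 = 110. This is only a shape sketch; the helpers tile_label, concat_last and one_hot are hypothetical illustrations, not TensorFlow functions, and the constant inputs stand in for real MNIST rows and noise.

```python
# Shape sketch (pure Python, no TensorFlow) of the label conditioning in the
# question: y of shape [batch, y_dim] is repeated along a new time axis and
# concatenated to every timestep of x (or z) along the last axis.


def tile_label(y, timesteps):
    """Hypothetical helper mimicking tf.tile(tf.expand_dims(y, 1), [1, T, 1]).

    y: list of batch rows, each a list of y_dim floats.
    Returns a nested list of shape [batch, timesteps, y_dim].
    """
    return [[list(row) for _ in range(timesteps)] for row in y]


def concat_last(a, b):
    """Hypothetical helper mimicking tf.concat([a, b], axis=2) on 3-D nested lists."""
    return [[step_a + step_b for step_a, step_b in zip(seq_a, seq_b)]
            for seq_a, seq_b in zip(a, b)]


def one_hot(label, depth=10):
    return [1.0 if i == label else 0.0 for i in range(depth)]


timesteps, X_dim, Z_dim, y_dim = 28, 28, 100, 10
batch = 2

y = [one_hot(3), one_hot(7)]                                           # [batch, y_dim]
x = [[[0.0] * X_dim for _ in range(timesteps)] for _ in range(batch)]  # placeholder "images"
z = [[[0.5] * Z_dim for _ in range(timesteps)] for _ in range(batch)]  # placeholder noise

y_timesteps = tile_label(y, timesteps)   # [batch, timesteps, y_dim]
d_inputs = concat_last(x, y_timesteps)   # [batch, timesteps, X_dim + y_dim]
g_inputs = concat_last(z, y_timesteps)   # [batch, timesteps, Z_dim + y_dim]

print(len(d_inputs), len(d_inputs[0]), len(d_inputs[0][0]))  # 2 28 38
print(len(g_inputs), len(g_inputs[0]), len(g_inputs[0][0]))  # 2 28 110
```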

There is most likely something wrong with my model. Could anyone help me make the generator network converge?
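One way to read the loss curves described above is to plug constant discriminator outputs into the question's per-sample loss formulas (plain Python, no TensorFlow). If the discriminator could not tell real from fake at all (D = 0.5 everywhere), D_loss would sit near 2·log 2 ≈ 1.386 and G_loss near log 2 ≈ 0.693, so a D_loss close to zero means the discriminator is winning outright. The 0.999/0.001 values below are illustrative, not taken from the question's training run.

```python
import math


def d_loss(d_real, d_fake):
    """Question's discriminator loss for one real/fake pair: -(log D(x) + log(1 - D(G(z))))."""
    return -(math.log(d_real) + math.log(1.0 - d_fake))


def g_loss(d_fake):
    """Question's generator loss for one fake sample: -log D(G(z))."""
    return -math.log(d_fake)


# Equilibrium: the discriminator outputs 0.5 for real and fake alike.
print(d_loss(0.5, 0.5), 2 * math.log(2))  # both about 1.386
print(g_loss(0.5), math.log(2))           # both about 0.693

# Illustrative "discriminator wins" case: D(x) near 1, D(G(z)) near 0.
print(d_loss(0.999, 0.001))  # about 0.002, i.e. close to zero
print(g_loss(0.001))         # about 6.9
```

If D(G(z)) reaches exactly 0, -log D(G(z)) becomes infinite (math.log(0.0) raises ValueError here, and tf.log returns -inf), which is one motivation for the logit-based losses in the answer.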

asked Feb 25 '18 at 01:02 by user1518183



1 Answer

There are a few things you can do to improve your network architecture and training phase.

  1. Remove tf.nn.sigmoid(logit) from both the generator and the discriminator, and return the raw logit instead of pred.
  2. Compute the losses with a numerically stable function, and fix the loss definitions themselves. This:

    D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
    G_loss = -tf.reduce_mean(tf.log(D_fake))

should be:

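# Both terms take the discriminator's logits, not its sigmoid outputs.
D_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
              logits=D_logit_real,
              labels=tf.ones_like(D_logit_real))
D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
              logits=D_logit_fake,
              labels=tf.zeros_like(D_logit_fake))

# The cross-entropy is already non-negative, so there is no leading minus.
D_loss = tf.reduce_mean(D_loss_real + D_loss_fake)
# The generator is scored on the fake logits, labelled as "real".
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
              logits=D_logit_fake,
              labels=tf.ones_like(D_logit_fake)))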
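To see why the logit-based loss is more stable, here is a pure-Python sketch of the formula the TensorFlow documentation gives for tf.nn.sigmoid_cross_entropy_with_logits, max(x, 0) - x*z + log(1 + exp(-|x|)) for logits x and labels z, next to the naive version computed from sigmoid(x). The function names are hypothetical and this only illustrates the arithmetic, not TensorFlow's kernel. Python floats are 64-bit, so the naive version breaks at larger logits here than it would in the question's float32 graph.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def naive_bce(logit, label):
    """Binary cross-entropy from p = sigmoid(logit); label must be 0 or 1.

    Mirrors tf.log(D) / tf.log(1 - D) in the question: breaks when p rounds to 0 or 1.
    """
    p = sigmoid(logit)
    return -math.log(p) if label == 1 else -math.log(1.0 - p)


def stable_bce_with_logits(logit, label):
    """max(x, 0) - x*z + log(1 + exp(-|x|)), the formula documented for
    tf.nn.sigmoid_cross_entropy_with_logits (x = logit, z = label)."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# For moderate logits the two formulas agree.
print(naive_bce(0.5, 1), stable_bce_with_logits(0.5, 1))

# A confident discriminator (logit 40) on a fake sample (label 0):
# sigmoid(40.0) rounds to exactly 1.0, so the naive version takes log(0).
try:
    naive_bce(40.0, 0)
except ValueError:
    print("naive formula failed: log(0)")
print(stable_bce_with_logits(40.0, 0))  # finite, about 40

# Pitfall: passing sigmoid outputs where logits are expected applies the
# sigmoid twice, so even a perfect discriminator cannot push the loss
# below log(1 + exp(-1)), roughly 0.313.
print(stable_bce_with_logits(20.0, 1), stable_bce_with_logits(sigmoid(20.0), 1))
```

The last line is why the discriminator should hand raw logits, not sigmoid outputs, to the with-logits loss.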

Once you have fixed the losses and switched to a numerically stable function, training should behave better. Also, as a rule of thumb, if the loss is very noisy, reduce the learning rate: Adam's default learning rate (0.001 in TensorFlow) is usually too high for training GANs. Hope it helps.
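A rough way to see why the learning rate matters so much with Adam: with bias correction, the first update for each parameter is about lr times the sign of the gradient, whatever the gradient's scale, so lr directly sets how far the weights jump. Below is a minimal pure-Python sketch of one Adam update for a scalar parameter, using the Adam paper's default hyperparameters (beta1=0.9, beta2=0.999, eps=1e-8); the function name adam_step is hypothetical. A setting often used for GANs, from the DCGAN paper, is lr=2e-4 with beta1=0.5, which in the question's TF1 code would be tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; returns (new_param, m, v).

    t is the 1-based step count used for bias correction.
    """
    m = beta1 * m + (1.0 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)                # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    new_param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return new_param, m, v


# First step from the same starting point with gradients of very different scale:
# the step size tracks lr, not the gradient magnitude.
for grad in (0.001, 1.0, 1000.0):
    for lr in (1e-3, 2e-4):
        new_param, _, _ = adam_step(0.0, grad, 0.0, 0.0, t=1, lr=lr)
        print(f"grad={grad:>8}, lr={lr}: step = {-new_param:.6g}")
```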

answered Oct 07 '22 at 13:10 by nessuno
