Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

None dimension raise ValueError in batch_norm with Tensorflow

I have implemented a certain kind of neural networks (GAN: Generative Adversarial Networks) in tensorflow.

It worked as expected until I decided to add the following batch normalization layer in the generator(z) method (see full code below):

out = tf.contrib.layers.batch_norm(out, is_training=False)

as I get the following error:

    G_sample = generator(Z)
  File "/Users/Florian/Documents/DeepLearning/tensorflow_stuff/tensorflow_stuff/DCGAN.py", line 84, in generator
    out = tf.contrib.layers.batch_norm(out, is_training=False)                                    
  File "/Users/Florian/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "/Users/Florian/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 551, in batch_norm
    outputs = layer.apply(inputs, training=is_training)
  File "/Users/Florian/anaconda2/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 381, in apply
    return self.__call__(inputs, **kwargs)
  File "/Users/Florian/anaconda2/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 328, in __call__
    self.build(input_shapes[0])
  File "/Users/Florian/anaconda2/lib/python2.7/site-packages/tensorflow/python/layers/normalization.py", line 143, in build
    input_shape)
ValueError: ('Input has undefined `axis` dimension. Input shape: ', TensorShape([Dimension(None), Dimension(None), Dimension(None), Dimension(None)]))

The problem seems to be due to the [None, None, None, None] shape of the input out but I do not know how I can fix this.

Here is the full code:

from __future__ import division
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import batch_norm 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os




def leaky_relu(x, alpha):
    return tf.maximum(alpha * x, x)




def discriminator(x):

    with tf.variable_scope('discriminator', reuse=True):

        # conv_2D accepts shape (batch, height, width, channel) as input so
        # reshape it
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        out = tf.nn.conv2d(x, tf.get_variable('D_w_1'), strides=[1, 2, 2, 1], padding='SAME') 
        out = leaky_relu(out, alpha=0.2)
        #out = tf.nn.dropout(out, keep_prob=0.2)
        out = tf.nn.conv2d(out, tf.get_variable('D_w_2'), strides=[1, 2, 2, 1], padding='SAME') 
        out = leaky_relu(out, alpha=0.2)
        #out = tf.nn.dropout(out, keep_prob=0.2)

        # fully connected layer
        out = tf.reshape(out, shape=[-1, 7*7*128])
        D_logits = tf.matmul(out, tf.get_variable('D_w_fc_1'))
        #D_logits = tf.nn.sigmoid(D_logits)
        D_logits = leaky_relu(D_logits, alpha=0.2)

    return D_logits




def generator(z):

    with tf.variable_scope('generator', reuse=True):
        out = tf.matmul(z, tf.get_variable('G_w_fc_1'))
        out = tf.nn.relu(out)

        out = tf.reshape(out, shape=[-1, 7, 7, 128])

        out = tf.nn.conv2d_transpose(out, 
                                     tf.get_variable('G_w_deconv_1'), 
                                     output_shape=tf.stack([tf.shape(out)[0], 14, 14, 64]),
                                     strides=[1, 2, 2, 1],
                                     padding='SAME') 
        print(out.get_shape().as_list())
        out = tf.contrib.layers.batch_norm(out, is_training=False)                                    
        out = tf.nn.relu(out)

        out = tf.nn.conv2d_transpose(out, 
                                     tf.get_variable('G_w_deconv_2'), 
                                     output_shape=tf.stack([tf.shape(out)[0], 28, 28, 1]),
                                     strides=[1, 2, 2, 1],
                                     padding='SAME') 
        out = tf.nn.tanh(out)


    return out







def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])


def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


if __name__ == '__main__':


    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

    batch_size = 128
    # size of generator input
    Z_dim = 10 
    # batch within an epoch
    batches_per_epoch = int(np.floor(mnist.train.num_examples / batch_size))
    nb_epochs = 20

    # learning rate
    learning_rate = 0.00005 # 0.0002

    Z = tf.placeholder(tf.float32, [batch_size, Z_dim])
    X = tf.placeholder(tf.float32, [batch_size, 784])

    with tf.variable_scope('discriminator'):
        D_w_1 = tf.get_variable('D_w_1', initializer=tf.random_normal([5, 5, 1, 64], stddev=0.02))
        D_w_2 = tf.get_variable('D_w_2', initializer=tf.random_normal([5, 5, 64, 128], stddev=0.02))
        D_w_fc_1 = tf.get_variable('D_w_fc_1', initializer=tf.random_normal([7*7*128, 1], stddev=0.02)) 

    D_var_list = [D_w_1, D_w_2, D_w_fc_1]


    with tf.variable_scope('generator'):
        G_w_fc_1 = tf.get_variable('G_w_fc_1', initializer=tf.random_normal([Z_dim, 128*7*7], stddev=0.02))
        G_w_deconv_1 = tf.get_variable('G_w_deconv_1', initializer=tf.random_normal([5, 5, 64, 128], stddev=0.02))
        G_w_deconv_2 = tf.get_variable('G_w_deconv_2', initializer=tf.random_normal([5, 5, 1, 64], stddev=0.02))

    G_var_list = [G_w_fc_1, G_w_deconv_1, G_w_deconv_2]


    G_sample = generator(Z)
    D_logit_real = discriminator(X)
    D_logit_fake = discriminator(G_sample)


    # objective functions
    # discriminator aims at maximizing the probability of TRUE data (i.e. from the dataset) and minimizing the probability
    # of GENERATED/FAKE data:
    D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
    D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
    D_loss = D_loss_real + D_loss_fake

    # generator aims at maximizing the probability of GENERATED/FAKE data (i.e. fool the discriminator)
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(D_loss, var_list=D_var_list)
    # when optimizing generator, discriminator is kept fixed
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(G_loss, var_list=G_var_list)


    with tf.Session() as sess:    

        sess.run(tf.global_variables_initializer())

        if not os.path.exists('out/'):
            os.makedirs('out/')

        for i_epoch in range(nb_epochs):

            G_loss_val = 0
            D_loss_val = 0

            for i_batch in range(batches_per_epoch):
                print('batch %i/%i' % (i_batch+1, batches_per_epoch))

                X_mb, _ = mnist.train.next_batch(batch_size)

                # train discriminator
                _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(batch_size, Z_dim)})
                D_loss_val += D_loss_curr

                # train generator
                _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, Z_dim)})
                G_loss_val += G_loss_curr

                if i_batch % 50 == 0:
                    samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

                    fig = plot(samples)
                    plt.savefig('out/%i_%i.png' % (i_epoch, i_batch), bbox_inches='tight')
                    plt.close(fig)





            print('Iter: {}'.format(i_epoch))
            print('D loss: {:.4}'.format(D_loss))
            print('G_loss: {:.4}'.format(G_loss))
like image 691
floflo29 Avatar asked Apr 18 '17 19:04

floflo29


2 Answers

If you pass a constant shape such as [100, 14, 14, 64] as output_shape, conv2d_transpose will return a tensor with the correct shape set. But if you pass in a non-constant tensor (which you have to do if you don't know the batch size in advance), conv2d_transpose assumes that it can't know the shape until the graph is run, and returns an all-None shape during construction.

Theoretically it could have realized that some of the dimensions are constant, but this isn't done at the moment.

You can work around this by using out.set_shape([None, 14, 14, 64]), or out = tf.reshape(out, [-1, 14, 14, 64]). There is no need to set the size of the batch dimension as batch_norm does not require it.

This is discussed on tensorflow issues 833 and 8972.

like image 53
interjay Avatar answered Nov 16 '22 21:11

interjay


Working code is below. There were a few minor errors in the code - maybe from your testing before posting the question - or maybe you hadn't executed it fully yet, my edits are noted with #EDIT: . You need to define the shape in order to use Batch normalization, and you could do this in advance if you wanted to but your suggestion is fine. I prefer to use reshape with a variable dimension by using -1 out = tf.reshape(out, [-1, 14, 14, 64]). The code below works on TF > 1 and python > 3.5.

from __future__ import division
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import batch_norm 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os




def leaky_relu(x, alpha):
    return tf.maximum(alpha * x, x)




def discriminator(x):

    with tf.variable_scope('discriminator', reuse=True):

        # conv_2D accepts shape (batch, height, width, channel) as input so
        # reshape it
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        out = tf.nn.conv2d(x, tf.get_variable('D_w_1'), strides=[1, 2, 2, 1], padding='SAME') 
        out = leaky_relu(out, alpha=0.2)
        #out = tf.nn.dropout(out, keep_prob=0.2)
        out = tf.nn.conv2d(out, tf.get_variable('D_w_2'), strides=[1, 2, 2, 1], padding='SAME') 
        out = leaky_relu(out, alpha=0.2)
        #out = tf.nn.dropout(out, keep_prob=0.2)

        # fully connected layer
        out = tf.reshape(out, shape=[-1, 7*7*128])
        D_logits = tf.matmul(out, tf.get_variable('D_w_fc_1'))
        #D_logits = tf.nn.sigmoid(D_logits)
        D_logits = leaky_relu(D_logits, alpha=0.2)

    return D_logits




def generator(z):

    with tf.variable_scope('generator', reuse=True):
        out = tf.matmul(z, tf.get_variable('G_w_fc_1'))
        out = tf.nn.relu(out)

        out = tf.reshape(out, shape=[-1, 7, 7, 128])

        out = tf.nn.conv2d_transpose(out, 
                                     tf.get_variable('G_w_deconv_1'),
                                     output_shape=tf.stack([tf.shape(out)[0], 14, 14, 64]),
                                     strides=[1, 2, 2, 1],
                                     padding='SAME') 
        print(out.get_shape().as_list())
        out = tf.reshape(out, [-1, 14, 14, 64])   #EDIT: You need to define the shape for batch_norm

        #out.set_shape([out.get_shape().as_list()[0], 14, 14, 64])
        out = tf.contrib.layers.batch_norm(out, is_training=False)
        out = tf.nn.relu(out)

        out = tf.nn.conv2d_transpose(out, 
                                     tf.get_variable('G_w_deconv_2'), 
                                     output_shape=tf.stack([tf.shape(out)[0], 28, 28, 1]),
                                     strides=[1, 2, 2, 1],
                                     padding='SAME') 
        out = tf.nn.tanh(out)


    return out

def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])


def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(12, 12)  #EDIT:  This wasn't large enough for the dataset.
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


if __name__ == '__main__':


    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

    batch_size = 128
    # size of generator input
    Z_dim = 10 
    # batch within an epoch
    batches_per_epoch = int(np.floor(mnist.train.num_examples / batch_size))
    nb_epochs = 20

    # learning rate
    learning_rate = 0.00005 # 0.0002

    Z = tf.placeholder(tf.float32, [batch_size, Z_dim])
    X = tf.placeholder(tf.float32, [batch_size, 784])

    with tf.variable_scope('discriminator'):
        D_w_1 = tf.get_variable('D_w_1', initializer=tf.random_normal([5, 5, 1, 64], stddev=0.02))
        D_w_2 = tf.get_variable('D_w_2', initializer=tf.random_normal([5, 5, 64, 128], stddev=0.02))
        D_w_fc_1 = tf.get_variable('D_w_fc_1', initializer=tf.random_normal([7*7*128, 1], stddev=0.02)) 

    D_var_list = [D_w_1, D_w_2, D_w_fc_1]


    with tf.variable_scope('generator'):
        G_w_fc_1 = tf.get_variable('G_w_fc_1', initializer=tf.random_normal([Z_dim, 128*7*7], stddev=0.02))
        G_w_deconv_1 = tf.get_variable('G_w_deconv_1', initializer=tf.random_normal([5, 5, 64, 128], stddev=0.02))
        G_w_deconv_2 = tf.get_variable('G_w_deconv_2', initializer=tf.random_normal([5, 5, 1, 64], stddev=0.02))

    G_var_list = [G_w_fc_1, G_w_deconv_1, G_w_deconv_2]


    G_sample = generator(Z)
    D_logit_real = discriminator(X)
    D_logit_fake = discriminator(G_sample)


    # objective functions
    # discriminator aims at maximizing the probability of TRUE data (i.e. from the dataset) and minimizing the probability
    # of GENERATED/FAKE data:
    D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
    D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
    D_loss = D_loss_real + D_loss_fake

    # generator aims at maximizing the probability of GENERATED/FAKE data (i.e. fool the discriminator)
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(D_loss, var_list=D_var_list)
    # when optimizing generator, discriminator is kept fixed
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(G_loss, var_list=G_var_list)


    with tf.Session() as sess:    

        sess.run(tf.global_variables_initializer())

        if not os.path.exists('out/'):
            os.makedirs('out/')

        for i_epoch in range(nb_epochs):

            G_loss_val = 0
            D_loss_val = 0

            for i_batch in range(batches_per_epoch):
                print('batch %i/%i' % (i_batch+1, batches_per_epoch))

                X_mb, _ = mnist.train.next_batch(batch_size)

                # train discriminator
                _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(batch_size, Z_dim)})
                D_loss_val += D_loss_curr

                # train generator
                _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, Z_dim)})
                G_loss_val += G_loss_curr

                if i_batch % 50 == 0:
                    samples = sess.run(G_sample, feed_dict={Z: sample_Z(batch_size, Z_dim)})  #EDIT: changed to batch_size to match the tensor

                    fig = plot(samples)
                    plt.savefig('out/%i_%i.png' % (i_epoch, i_batch), bbox_inches='tight')
                    plt.close(fig)

            print('Iter: {}'.format(i_epoch))
            print('D loss: {:.4}'.format(D_loss_curr)) #EDIT: You were trying to print the tensor.
            print('G_loss: {:.4}'.format(G_loss_curr))#EDIT: You were trying to print the tensor.
like image 26
The Puternerd Avatar answered Nov 16 '22 21:11

The Puternerd