Add class information to keras network

I am trying to figure out how to use the label information of my dataset with Generative Adversarial Networks. I am using the implementation of conditional GANs that can be found here. My dataset contains two different image domains (real objects and sketches) with common class information (chair, tree, orange, etc.). I opted for this implementation because it only considers the two different domains as different "classes" for the correspondence (train samples X correspond to the real images while target samples y correspond to the sketch images).

Is there a way to modify my code so that the class information (chair, tree, etc.) is taken into account across my whole architecture? I actually want my discriminator to predict whether the images produced by the generator belong to a specific class, not only whether they are real or not. As it is, with the current architecture, the system learns to create similar sketches in all cases.

Update: The discriminator returns a tensor of size 1x7x7; both y_true and y_pred are then passed through a flatten layer before the loss is calculated:

def discriminator_loss(y_true, y_pred):
    # y_true is ignored: the targets (ones for the real half of the batch,
    # zeros for the fake half) are rebuilt from the batch layout instead.
    BATCH_SIZE = 100
    flat_pred = K.flatten(y_pred)
    half = K.flatten(y_pred[:BATCH_SIZE, :, :, :])
    return K.mean(K.binary_crossentropy(
        flat_pred, K.concatenate([K.ones_like(half), K.zeros_like(half)])), axis=-1)

and the loss function used when training the generator through the (frozen) discriminator:

def discriminator_on_generator_loss(y_true, y_pred):
    # y_true is ignored here too; the generator tries to make the
    # discriminator output "real" (ones) everywhere.
    return K.mean(K.binary_crossentropy(
        K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

Furthermore, here is my modification of the discriminator model so that it outputs a single unit:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

Now the discriminator outputs a single value. How should I modify the above-mentioned loss functions correspondingly? Should the output have 7 units instead of 1, i.e. n_classes = 6 plus one class for predicting real vs. fake pairs?

Asked Jun 18 '18 by Jose Ramon

2 Answers

Suggested Solution

Reusing the code from the repository you shared, here are some suggested modifications to train a classifier alongside your generator and discriminator (their architectures and other losses are left untouched):

import numpy as np
import cv2

from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Input, merge
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import Adagrad

def lenet_classifier_model(nb_classes, in_shape=(IN_CH, img_cols, img_rows)):
    # Snippet by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier... (adjust in_shape to your sketches)
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
    return model

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)

    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)

    classifier.trainable = False
    x_classifier = classifier(x_generator)

    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])

    return model


def train(BATCH_SIZE):
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)
    d_optim = Adagrad(lr=0.005)
    g_optim = Adagrad(lr=0.005)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here

            generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D:
            real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                        axis=1)
            fake_pairs = np.concatenate(
                (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((2 * BATCH_SIZE, 1, 64, 64))  # dummy targets: discriminator_loss rebuilds the real/fake labels itself
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G:
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                [image_batch, np.ones((BATCH_SIZE, 1, 64, 64)), label_batch])  # dummy "real" targets; the custom loss ignores y_true
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

Theoretical Details

I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.

Role of the Discriminator

In the min-max game which GAN training is [4], the discriminator D plays against the generator G (the network you actually care about) so that, under D's scrutiny, G becomes better at outputting realistic results.

For that, D is trained to tell real samples apart from samples produced by G, while G is trained to fool D by generating realistic results, i.e. results following the target distribution.
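For reference, this is the min-max objective from [4] (the value function that D maximizes and G minimizes):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$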

Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. a real picture) to another domain B (e.g. a sketch), D is usually fed with the pairs of samples stacked together and has to discriminate "real" pairs (input sample from A + corresponding target sample from B) from "fake" pairs (input sample from A + corresponding output from G) [1, 2].

Training a conditional generator against D (as opposed to simply training G alone with an L1/L2 loss only, e.g. as a denoising autoencoder) improves the sampling capability of G, forcing it to output crisp, realistic results instead of averaging over the distribution.

Even though discriminators can have multiple sub-networks to cover other tasks (see the next paragraphs), D should keep at least one sub-network/output covering its main task: telling real samples apart from generated ones. Asking D to regress further semantic information (e.g. classes) alongside may interfere with this main purpose.

Note: D's output is often not a simple scalar/boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) return a matrix of probabilities, evaluating how realistic the patches of its input are.
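For illustration, here is a minimal sketch of such a patch-based discriminator, assuming the channel-first 64x64 inputs used elsewhere in this answer and the same Keras-1-style API (layer sizes are made up):

from keras.models import Sequential
from keras.layers.convolutional import Convolution2D
from keras.layers.advanced_activations import LeakyReLU

def patch_discriminator_model(in_shape=(6, 64, 64)):
    # Strided convolutions shrink the input; the final 1-channel sigmoid map
    # scores each receptive field ("patch") as real/fake instead of
    # returning a single scalar.
    model = Sequential()
    model.add(Convolution2D(64, 4, 4, subsample=(2, 2), border_mode='same',
                            input_shape=in_shape))
    model.add(LeakyReLU(0.2))
    model.add(Convolution2D(128, 4, 4, subsample=(2, 2), border_mode='same'))
    model.add(LeakyReLU(0.2))
    model.add(Convolution2D(1, 4, 4, border_mode='same', activation='sigmoid'))
    return model  # output: (batch, 1, 16, 16) patch probabilities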


Conditional GANs

Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector as input [4].

As previously mentioned, conditional GANs have further input conditions. Along with, or instead of, the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can even be a completely different modality, e.g. B = sketch image while A = discrete label; B = volumetric data while A = RGB image; etc. [3]
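As a rough sketch of one common conditioning mechanism (names and shapes here are hypothetical, not from the repository): a discrete label can be projected to an extra image plane and stacked onto the input channels, so every downstream convolution sees the class information:

from keras.models import Model
from keras.layers import Input, Dense, Reshape, merge

def label_conditioned_input(n_classes=6, in_ch=3, img_rows=64, img_cols=64):
    # Project the one-hot label to a single spatial "label plane" and
    # concatenate it with the image along the channel axis.
    image = Input((in_ch, img_rows, img_cols))
    label = Input((n_classes,))
    plane = Dense(img_rows * img_cols, activation='relu')(label)
    plane = Reshape((1, img_rows, img_cols))(plane)
    conditioned = merge([image, plane], mode='concat', concat_axis=1)
    return Model(input=[image, label], output=conditioned)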

Such GANs can also be conditioned on multiple inputs, e.g. A = real image + discrete label while B = sketch image. A famous work introducing such methods is InfoGAN [5]. It presents how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing type, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its corresponding outputs.


Maximizing the Mutual Information for cGANs

The InfoGAN discriminator has 2 heads/sub-networks to cover its 2 tasks [5]:

  • One head D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that it can't tell real from generated data;
  • Another head D2 (also named the Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data which "show" the requested semantic information (cf. mutual-information maximization between G's conditional inputs and its outputs).

You can find a Keras implementation here for instance: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.
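To make the two-head structure concrete, here is a rough sketch loosely following [5] (layer sizes and names are hypothetical; the shared trunk and two output heads are the point):

from keras.models import Model
from keras.layers import Input, Dense, Flatten
from keras.layers.convolutional import Convolution2D
from keras.layers.advanced_activations import LeakyReLU

def two_headed_discriminator(n_classes=6, in_shape=(3, 64, 64)):
    inputs = Input(in_shape)
    # Shared convolutional trunk
    x = Convolution2D(64, 4, 4, subsample=(2, 2), border_mode='same')(inputs)
    x = LeakyReLU(0.2)(x)
    x = Convolution2D(128, 4, 4, subsample=(2, 2), border_mode='same')(x)
    x = LeakyReLU(0.2)(x)
    x = Flatten()(x)
    # Head D1: real/generated probability
    validity = Dense(1, activation='sigmoid', name='d1_validity')(x)
    # Head Q: class posterior, used for mutual-information maximization
    label = Dense(n_classes, activation='softmax', name='q_label')(x)
    return Model(input=inputs, output=[validity, label])

# d = two_headed_discriminator()
# d.compile(optimizer='rmsprop',
#           loss=['binary_crossentropy', 'categorical_crossentropy'])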

Several works use similar schemes to improve control over what a GAN generates, by using provided labels and maximizing the mutual information between these inputs and G's outputs [6, 7]. The basic idea is always the same though:

  • Train G to generate elements of domain B, given some inputs of domain(s) A;
  • Train D to discriminate "real"/"fake" results -- G has to minimize this;
  • Train Q (e.g. a classifier; it can share layers with D) to estimate the original A inputs from B samples -- G has to maximize this.

Wrapping Up

In your case, it seems you have the following training data:

  • real images Ia
  • corresponding sketch images Ib
  • corresponding class labels c

And you want to train a generator G so that given an image Ia and its class label c, it outputs a proper sketch image Ib'.

All in all, that's a lot of information you have, and you can supervise your training both on the conditioned images and the conditioned labels... Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is a possible way of using all this information to train your conditional G:

Network G:
  • Inputs: Ia + c
  • Output: Ib'
  • Architecture: up-to-you (e.g. U-Net, ResNet, ...)
  • Losses: L1/L2 loss between Ib' & Ib, adversarial loss (the opposite of D's loss), Q classification loss
Network D:
  • Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
  • Output: "fakeness" scalar/matrix
  • Architecture: up-to-you (e.g. PatchGAN)
  • Loss: cross-entropy on the "fakeness" estimation
Network Q:
  • Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
  • Output: c' (estimated class)
  • Architecture: up-to-you (e.g. LeNet, ResNet, VGG, ...)
  • Loss: cross-entropy between c and c'
Training Phase:
  1. Train D on a batch of real pairs Ia + Ib then on a batch of fake pairs Ia + Ib';
  2. Train Q on a batch of real samples Ib;
  3. Fix D and Q weights;
  4. Train G, passing its generated outputs Ib' to D and Q to back-propagate through them.

Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.
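As a skeletal illustration of the training phase above (all model and variable names here are hypothetical; D is assumed to output a single scalar per pair and G to take the image and label as two inputs):

import numpy as np

for Ia, Ib, c in batches:  # real images, real sketches, one-hot labels
    Ib_fake = G.predict([Ia, c])

    # 1. Train D on real pairs then on fake pairs (convention: 1 = real)
    d_loss_real = D.train_on_batch(np.concatenate([Ia, Ib], axis=1),
                                   np.ones((len(Ia), 1)))
    d_loss_fake = D.train_on_batch(np.concatenate([Ia, Ib_fake], axis=1),
                                   np.zeros((len(Ia), 1)))

    # 2. Train Q on real sketches only
    q_loss = Q.train_on_batch(Ib, c)

    # 3. + 4. With D and Q frozen inside the combined model
    # (combined maps [Ia, c] -> [Ib', D(pair), Q(Ib')]), update G:
    g_loss = combined.train_on_batch([Ia, c],
                                     [Ib, np.ones((len(Ia), 1)), c])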


References

  1. Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017). http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
  2. Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017). http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
  3. Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014). https://arxiv.org/pdf/1411.1784
  4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
  5. Chen, Xi, et al. "Infogan: Interpretable representation learning by information maximizing generative adversarial nets." Advances in Neural Information Processing Systems. 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
  6. Lee, Minhyeok, and Junhee Seok. "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598 (2017). https://arxiv.org/pdf/1708.00598.pdf
  7. Odena, Augustus, Christopher Olah, and Jonathon Shlens. "Conditional image synthesis with auxiliary classifier gans." arXiv preprint arXiv:1610.09585 (2016). http://proceedings.mlr.press/v70/odena17a/odena17a.pdf
Answered by benjaminplanche

You should modify your discriminator model either to have two outputs, or to have an "n_classes + 1" output.

Warning: in the definition of your discriminator, I don't see it outputting 'true/false'; I see it outputting an image...

Somewhere it should contain a GlobalMaxPooling2D or a GlobalAveragePooling2D layer, followed at the end by one or more Dense layers for classification.

If it only tells true/false, the last Dense layer should have 1 unit.
Otherwise, it should have n_classes + 1 units.

So, the ending of your discriminator should be something like:

...GlobalMaxPooling2D()...
...Dense(someHidden,...)...
...Dense(n_classes+1,...)...
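A minimal version of that ending could look like this (the hidden size is a placeholder, and n_classes is assumed to be defined in your script):

from keras.layers import GlobalMaxPooling2D, Dense

model.add(GlobalMaxPooling2D())           # collapse the spatial feature map
model.add(Dense(128, activation='relu'))  # hidden classification layer
# Option 1: sigmoid outputs + "binary_crossentropy"
model.add(Dense(n_classes + 1, activation='sigmoid'))
# Option 2: softmax outputs + "categorical_crossentropy"
#model.add(Dense(n_classes + 1, activation='softmax'))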

The discriminator will now output the n_classes plus either a "true/fake" sign (you will not be able to use "categorical" crossentropy there) or even a dedicated "fake class" (then you zero the other classes and use categorical crossentropy).

Your generated sketches should be passed to the discriminator along with a target that is the concatenation of the true/fake part with the class part.

Option 1 - Using the "true/fake" sign (don't use "categorical_crossentropy"):

#true sketches into discriminator:
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses  # one-hot, shape (total_samples, n_classes)

targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = originalClasses

targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)

Option 2 - Using the "fake class" (can use "categorical_crossentropy"):

#true sketches into discriminator:
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses  # one-hot, shape (total_samples, n_classes)

targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = np.zeros((total_fake_sketches, n_classes))  # all-zero classes for fakes

targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)

Now concatenate everything into a single target array (in the same order as the input sketches).
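For instance, with a batch of 5 true sketches and 6 classes (made-up class indices), the assembled targets look like:

import numpy as np

isGenerated = np.zeros((5, 1))        # 0 = true sketch for all 5 samples
classes = np.eye(6)[[0, 2, 1, 5, 3]]  # one-hot class rows, shape (5, 6)
targets = np.concatenate([isGenerated, classes], axis=1)  # shape (5, 7)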

Updated training method

For this training method, your loss function should be one of:

  • discriminator.compile(loss='binary_crossentropy', optimizer=....)
  • discriminator.compile(loss='categorical_crossentropy', optimizer=...)

Code:

for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))

    for index in range(int(X_train.shape[0]/BATCH_SIZE)):

        #names:
            #images -> initial images, not changed    
            #sketches -> generated + true sketches    
            #classes -> your classification for the images    
            #isGenerated -> the output of your discriminator telling whether the passed sketches are fake

        batchSlice = slice(index*BATCH_SIZE,(index+1)*BATCH_SIZE)
        trueImages = X_train[batchSlice]

        trueSketches = Y_train[batchSlice] 
        trueClasses = originalClasses[batchSlice]
        trueIsGenerated = np.zeros((len(trueImages), 1)) #discriminator flag telling whether the sketch is fake or true (generated images = 1)
        trueEndTargets = np.concatenate([trueIsGenerated, trueClasses], axis=1)

        fakeSketches = generator.predict(trueImages)
        #fakeClasses = originalClasses[batchSlice]            #option 1 -> true class + isGenerated flag - use "binary_crossentropy"
        fakeClasses = np.zeros((len(fakeSketches), n_classes)) #option 2 -> all-zero classes for generated samples - use "categorical_crossentropy"
        fakeIsGenerated = np.ones((len(fakeSketches), 1))
        fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

        allSketches = np.concatenate([trueSketches,fakeSketches],axis=0)            
        allEndTargets = np.concatenate([trueEndTargets,fakeEndTargets],axis=0)

        d_loss = discriminator.train_on_batch(allSketches, allEndTargets)

        pred_temp = discriminator.predict(allSketches)
        #print(np.shape(pred_temp))
        print("batch %d d_loss : %f" % (index, d_loss))

        ##WARNING## In previous Keras versions, "trainable" only takes effect when you compile the model.
        #You should create "discriminator" and "discriminator_on_generator" with these flags already set, and never change them again.

        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
        discriminator.trainable = True


        print("batch %d g_loss : %f" % (index, g_loss[1]))
        if index % 20 == 0:
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)

Compiling the models properly

When you create "discriminator" and "discriminator_on_generator":

discriminator.trainable = True
for l in discriminator.layers:
    l.trainable = True


discriminator.compile(.....)

for l in discriminator_on_generator.layers[firstDiscriminatorLayer:]:
    l.trainable = False

discriminator_on_generator.compile(....)
Answered by Daniel Möller