I am trying to figure out how to use the label information of my dataset with Generative Adversarial Networks. I am trying to use the following implementation of conditional GANs that can be found here. My dataset contains two different image domains (real objects and sketches) with common class information (chair, tree, orange, etc.). I opted for this implementation, which only considers the two different domains as different "classes" for the correspondence (train samples `X` correspond to the real images, while target samples `y` correspond to the sketch images).
Is there a way to modify my code and take into account the class information (chair, tree, etc.) in my whole architecture? I actually want my discriminator to predict whether or not the images generated by my generator belong to a specific class, and not only whether they are real or not. As it is, with the current architecture, the system learns to create similar sketches in all cases.
Update: The discriminator returns a tensor of size `1x7x7`; then both `y_true` and `y_pred` are passed through a flatten layer before calculating the loss:
def discriminator_loss(y_true, y_pred):
    """Binary cross-entropy loss for the pair discriminator.

    ``y_pred`` holds the discriminator outputs for a batch of real pairs
    followed by an equally sized batch of fake pairs. The first half of
    the (flattened) predictions is pushed towards 1 ("real") and the
    second half towards 0 ("fake").

    ``y_true`` is ignored: targets are derived from each sample's position
    in the batch, so callers only need to pass a placeholder array of the
    right shape.
    """
    flat_pred = K.flatten(y_pred)
    # Real and fake halves contain the same number of samples, each with the
    # same number of output cells, so after a row-major flatten the real
    # predictions occupy exactly the first half of the vector. Computing the
    # split from the tensor itself (instead of a hard-coded BATCH_SIZE=100)
    # keeps the targets the same length as the predictions for any batch size.
    half = K.shape(flat_pred)[0] // 2
    targets = K.concatenate([K.ones_like(flat_pred[:half]),
                             K.zeros_like(flat_pred[half:])])
    # Argument order (output, target) matches the original call; it follows the
    # Keras 1 backend signature (Keras 2 swapped it to (target, output)).
    return K.mean(K.binary_crossentropy(flat_pred, targets), axis=-1)
and the loss function of the discriminator over the generator:
def discriminator_on_generator_loss(y_true, y_pred):
    """Adversarial loss for the generator: push every D output towards 1.

    Used when training the generator through the frozen discriminator, so
    the generator is rewarded for pairs D judges "real". ``y_true`` is
    ignored; every prediction cell is compared against a target of 1.
    """
    # The original defined an unused BATCH_SIZE local here; the loss does
    # not depend on the batch size, so it has been removed.
    flat_pred = K.flatten(y_pred)
    # Argument order (output, target) matches the original call (Keras 1
    # backend signature).
    return K.mean(K.binary_crossentropy(flat_pred, K.ones_like(flat_pred)), axis=-1)
Furthermore, here is my modification of the discriminator model so that it outputs a single unit:
# Discriminator head: flatten the preceding layer's output and map it to a
# single sigmoid unit (one real/fake probability per sample).
for head_layer in (Flatten(), Dense(1, activation='sigmoid')):
    model.add(head_layer)
Now the discriminator outputs a single unit. How can I modify the above-mentioned loss functions correspondingly? Should I have 7 outputs instead of 1, for `n_classes = 6` plus one class for predicting real and fake pairs?
Reusing the code from the repository you shared, here are some suggested modifications to train a classifier along your generator and discriminator (their architectures and other losses are left untouched):
from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
def lenet_classifier_model(nb_classes):
    """Build a small LeNet-style CNN classifier (uncompiled).

    Two conv + max-pool stages are followed by two dropout-regularized
    dense layers and a softmax over ``nb_classes`` outputs.

    Args:
        nb_classes: number of classes, i.e. the size of the softmax output.

    Returns:
        The uncompiled ``Sequential`` model.

    NOTE(review): the input shape is read from a module-level ``in_shape``
    that is not defined in this snippet. It presumably has to match the
    sketch images this classifier is fed (``image_batch`` / generator
    output in ``train``) -- confirm.
    """
    # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier...
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
    # The original never returned the model, so callers received None.
    return model
def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    """Stack G with D and C into a single model used to train G.

    The returned model maps an input image batch to three outputs:
      * the generated image G(x);
      * D's verdict on the pair (x, G(x)) concatenated along axis 1;
      * C's class prediction for G(x).

    D and C are flagged ``trainable = False`` before they are applied, so
    that compiling this model is meant to train only the generator.
    """
    source = Input((IN_CH, img_cols, img_rows))
    fake = generator(source)

    stacked_pair = merge([source, fake], mode='concat', concat_axis=1)
    discriminator.trainable = False
    verdict = discriminator(stacked_pair)

    classifier.trainable = False
    predicted_class = classifier(fake)

    outputs = [fake, verdict, predicted_class]
    return Model(input=source, output=outputs)
def train(BATCH_SIZE):
    """Jointly train the generator, the pair discriminator and the classifier.

    For every batch:
      1. D is trained on ``BATCH_SIZE`` real (input, sketch) pairs followed
         by ``BATCH_SIZE`` fake (input, generated sketch) pairs;
      2. the classifier C is trained on the real sketches and their labels;
      3. G is trained through the combined model, whose D and C parts were
         frozen when it was compiled.

    Every 20 batches a preview of the generated batch is written to
    ``<epoch>_<index>.png`` and the generator/discriminator weights are saved.

    Args:
        BATCH_SIZE: number of samples per batch. The loss functions used by
            the compiled models must be consistent with this value.
    """
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    # Rescale pixel values from [0, 255] to [-1, 1] (inverted for previews below).
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5

    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)  # number of classes in LABEL_train
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)

    # The combined model is compiled while D and C are still frozen (the helper
    # sets their trainable flag to False); D and C are then made trainable again
    # and compiled on their own. (The unused Adagrad optimizers of the original
    # were dropped: every model is compiled with "rmsprop".)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    # Adversarial targets must have one row per sample and match D's per-sample
    # output shape (e.g. a PatchGAN map). The original hard-coded (20, 1, 64, 64)
    # and (10, 1, 64, 64), which only worked for BATCH_SIZE == 10 and a 1x64x64
    # discriminator output.
    d_output_shape = tuple(discriminator.output_shape[1:])
    num_batches = X_train.shape[0] // BATCH_SIZE

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            batch = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
            input_batch = X_train[batch]
            image_batch = Y_train[batch]
            # NOTE(review): categorical_crossentropy expects one-hot labels of
            # shape (BATCH_SIZE, n_classes) -- confirm LABEL_train is encoded so.
            label_batch = LABEL_train[batch]  # replace with your data here
            generated_images = generator.predict(input_batch)
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D: real pairs first, then fake pairs.
            real_pairs = np.concatenate((input_batch, image_batch), axis=1)
            fake_pairs = np.concatenate((input_batch, generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((2 * BATCH_SIZE,) + d_output_shape)  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C on the real sketches.
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Training G: targets are (real sketch, "real" verdicts, true labels).
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                input_batch,
                [image_batch, np.ones((BATCH_SIZE,) + d_output_shape), label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)
I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.

In the min-max game that is GAN training [4], the discriminator D plays against the generator G (the network you actually care about) so that, under D's scrutiny, G becomes better at outputting realistic results.

For that, D is trained to tell real samples apart from samples produced by G, while G is trained to fool D by generating realistic results / results following the target distribution.
Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. real picture) to another domain B (e.g. sketch), D is usually fed with the pairs of samples stacked together and has to discriminate "real" pairs (input sample from A + corresponding target sample from B) from "fake" pairs (input sample from A + corresponding output from G) [1, 2].
Training a conditional generator against D (as opposed to simply training G alone with an L1/L2 loss only, e.g. a DAE) improves the sampling capability of G, forcing it to output crisp, realistic results instead of trying to average the distribution.
Even though discriminators can have multiple sub-networks to cover other tasks (see next paragraphs), D should keep at least one sub-network/output to cover its main task: telling real samples apart from generated ones. Asking D to regress further semantic information (e.g. classes) alongside may interfere with this main purpose.
Note: D's output is often not a simple scalar / boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) returning a matrix of probabilities, evaluating how realistic the patches made from its input are.
Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector as input. [4]
As previously mentioned, conditional GANs have further input conditions. Along with (or instead of) the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can be a completely different modality, e.g. B = sketch image while A = discrete label; B = volumetric data while A = RGB image, etc. [3]
Such GANs can also be conditioned by multiple inputs, e.g. A = real image + discrete label while B = sketch image. A famous work introducing such methods is InfoGAN [5]. It presents how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing type, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its corresponding outputs.
InfoGAN's discriminator has 2 heads/sub-networks to cover its 2 tasks [5]:

- D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that it can't tell real data apart from generated data;
- D2 (also named the Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data which "show" the requested semantic information (c.f. mutual-information maximization between G's conditional inputs and its outputs).

You can find a Keras implementation here, for instance: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.
Several works use similar schemes to improve control over what a GAN is generating, by using provided labels and maximizing the mutual information between these inputs and G's outputs [6, 7]. The basic idea is always the same though:

- train G to generate elements of domain B, given some inputs of domain(s) A;
- train D to discriminate "real"/"fake" results -- G has to minimize this;
- train Q (e.g. a classifier; it can share layers with D) to estimate the original A inputs from B samples -- G has to maximize this.

In your case, it seems you have the following training data:
- real images Ia;
- corresponding sketch images Ib;
- class labels c.

And you want to train a generator G so that, given an image Ia and its class label c, it outputs a proper sketch image Ib'.
All in all, that's a lot of information you have, and you can supervise your training both on the conditioned images and the conditioned labels...
Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is a possible way of using all this information to train your conditional G:

- Network G:
  - Inputs: Ia + c
  - Output: Ib'
  - Losses: a reconstruction loss (e.g. L1/L2) between Ib' & Ib, -D loss, Q loss
- Network D:
  - Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
  - Output: a real/fake estimate (scalar, or a PatchGAN-style matrix)
- Network Q:
  - Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
  - Output: c' (estimated class)
  - Loss: a classification loss (e.g. cross-entropy) between c and c'
- Training phase:
  1. Train D on a batch of real pairs Ia + Ib, then on a batch of fake pairs Ia + Ib';
  2. Train Q on a batch of real samples Ib;
  3. Fix the D and Q weights;
  4. Train G, passing its generated outputs Ib' to D and Q to back-propagate through them.

Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.
You should modify your discriminator model, either to have two outputs, or to have an "n_classes + 1" output.

Warning: I don't see your discriminator outputting 'true/false' in its definition; I see it outputting an image...

Somewhere it should contain a `GlobalMaxPooling2D` or a `GlobalAveragePooling2D`, followed at the end by one or more `Dense` layers for classification.

If telling true/false, the last `Dense` should have 1 unit; otherwise, it should have `n_classes + 1` units.

So, the ending of your discriminator should be something like:
...GlobalMaxPooling2D()...
...Dense(someHidden,...)...
...Dense(n_classes+1,...)...
The discriminator will now output `n_classes` values plus either a "true/fake" sign (you will not be able to use "categorical" there) or a "fake class" (then you zero the other classes and use categorical).

Your generated sketches should be passed to the discriminator along with a target that is the concatenation of the fake class with the other classes.
Option 1 - Using the "true/fake" sign. (Don't use "categorical_crossentropy")
# Option 1 targets: one "is generated" flag column followed by the class columns.
# NOTE(review): assumes originalClasses is one-hot encoded with shape
# (N, n_classes) -- confirm. The flag must be a column (N, 1): np.concatenate
# requires equal ndim, so the original 1-D np.zeros((N,)) could not be joined
# with a 2-D class array.
#true sketches into discriminator:
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses
targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)
#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = originalClasses
targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)
Option 2 - Using the "fake class" (can use "categorical_crossentropy"):
# Option 2 targets: an extra "fake" class column followed by the real classes,
# so each row is one-hot over n_classes + 1 entries.
# NOTE(review): assumes originalClasses is one-hot encoded with shape
# (N, n_classes) -- confirm. The flag must be a column (N, 1): np.concatenate
# requires equal ndim, so the original 1-D arrays could not be joined with
# the 2-D class arrays.
#true sketches into discriminator:
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses
targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)
#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = np.zeros((total_fake_sketches, n_classes))
targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)
Now concatenate everything into a single target array (respective to the input sketches)
For this training method, your loss function should be one of:
# Pick ONE of the two compile calls, matching how the discriminator targets
# were built. Replace the "...." / "..." placeholders with a real optimizer
# ("...." is not valid Python as written).
# Option 1 ("is generated" flag + real class columns):
discriminator.compile(loss='binary_crossentropy', optimizer=....)
# Option 2 (extra "fake" class, one-hot over n_classes + 1):
discriminator.compile(loss='categorical_crossentropy', optimizer=...)
Code:
# Training loop for a discriminator that sees sketches only and predicts
# [isGenerated, class...] per sample.
# NOTE(review): assumes originalClasses is one-hot with shape (N, n_classes).
num_batches = int(X_train.shape[0] // BATCH_SIZE)
for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", num_batches)
    for index in range(num_batches):
        #names:
        #images -> initial images, not changed
        #sketches -> generated + true sketches
        #classes -> your classification for the images
        #isGenerated -> the output of your discriminator telling whether the passed sketches are fake
        batchSlice = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
        trueImages = X_train[batchSlice]
        trueSketches = Y_train[batchSlice]
        trueClasses = originalClasses[batchSlice]
        # Discriminator flag: 0 = true sketch, 1 = generated. It must be a
        # column (N, 1) so it can be concatenated with the 2-D class array.
        trueIsGenerated = np.zeros((len(trueImages), 1))
        trueEndTargets = np.concatenate([trueIsGenerated, trueClasses], axis=1)

        fakeSketches = generator.predict(trueImages)
        # Option 1 -> telling class + isGenerated - use "binary_crossentropy":
        # fakeClasses = originalClasses[batchSlice]
        # Option 2 -> telling if generated is an individual class - use "categorical_crossentropy":
        fakeClasses = np.zeros((len(fakeSketches), n_classes))
        fakeIsGenerated = np.ones((len(fakeSketches), 1))
        fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

        allSketches = np.concatenate([trueSketches, fakeSketches], axis=0)
        allEndTargets = np.concatenate([trueEndTargets, fakeEndTargets], axis=0)
        d_loss = discriminator.train_on_batch(allSketches, allEndTargets)
        print("batch %d d_loss : %f" % (index, d_loss))

        ##WARNING## In previous keras versions, "trainable" only takes effect if you compile the models.
        #you should have the "discriminator" and the "discriminator_on_generator" with these set at the creation of the models and never change it again
        discriminator.trainable = False
        # The generator is trained to make D answer "not generated, correct class".
        g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
        discriminator.trainable = True
        # train_on_batch returns a list only for multi-output / multi-metric models.
        g_loss_value = g_loss[1] if isinstance(g_loss, (list, tuple)) else g_loss
        print("batch %d g_loss : %f" % (index, g_loss_value))
        if index % 20 == 0:
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)
When you create "discriminator" and "discriminator_on_generator":
# Compile the stand-alone discriminator with every layer trainable...
discriminator.trainable = True
for l in discriminator.layers:
    l.trainable = True
discriminator.compile(...)  # TODO: pass your loss and optimizer

# ...then freeze the discriminator part of the stacked model before compiling
# it, so that training discriminator_on_generator only updates the generator.
# firstDiscriminatorLayer is a placeholder: the index of the first
# discriminator layer in discriminator_on_generator.layers. Keras models expose
# `.layers`; the original `.layer` raised AttributeError.
for l in discriminator_on_generator.layers[firstDiscriminatorLayer:]:
    l.trainable = False
discriminator_on_generator.compile(...)  # TODO: pass your loss and optimizer
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With