I've been trying to solve this problem for weeks now and I'm at a total loss. My model is based mostly on the TensorFlow tutorial on the Transformer model for language understanding and takes inspiration from this paper, which uses a Transformer model for image captioning.
My goal is to teach a model to caption images from my own dataset. I've unit tested just about everything and it seems to be working perfectly, but my one remaining issue is that the model very quickly learns to predict only the end token. All the captions in my dataset are structured like "<start> caption text here <end>", but the model quickly learns to predict "<end> <end> <end> <end> ... "
The loss falls toward zero and the accuracy climbs above 90% in less than one epoch (which shouldn't be possible in an image captioning task), yet all the model predicts is the end token. I've tested the loss and accuracy functions on my own tensor filled with end tokens, and they give a high loss and 0% accuracy, which is correct, so I have no idea how the accuracy slowly climbs and the loss slowly falls when it seems to be predicting only end tokens.
Here is the beginning of the first epoch in one run. Every 50 batches I print the predictions (just the argmax along the vocab_size dimension) for the first example, to get a sample of the model's output and how it changes:
Epoch 1 Batch 0 Loss 9.0129 Accuracy 0.0000
Predictions: tf.Tensor([5817 6816 6626 2530 521 7248 3903 4040 2104 7952], shape=(10,), dtype=int64)
Epoch 1 Batch 50 Loss 9.0176 Accuracy 0.0001
Predictions: tf.Tensor([2904 3546 5805 3328 3021 3028 4687 2457 7491 6794], shape=(10,), dtype=int64)
Epoch 1 Batch 100 Loss 9.0173 Accuracy 0.0000
Predictions: tf.Tensor([5817 6589 2535 7221 4370 4191 6440 5486 4636 1857], shape=(10,), dtype=int64)
Epoch 1 Batch 150 Loss 9.0143 Accuracy 0.0000
Predictions: tf.Tensor([5817 6769 6709 955 6709 6284 6709 4732 1027 1027], shape=(10,), dtype=int64)
Epoch 1 Batch 200 Loss 9.0093 Accuracy 0.0000
Predictions: tf.Tensor([6337 2300 3304 6067 4284 33 6895 2457 6237 6125], shape=(10,), dtype=int64)
Epoch 1 Batch 250 Loss 9.0033 Accuracy 0.0001
Predictions: tf.Tensor([5817 5503 2889 554 5771 7612 196 1808 6237 6537], shape=(10,), dtype=int64)
Epoch 1 Batch 300 Loss 8.9953 Accuracy 0.0002
Predictions: tf.Tensor([6067 5 521 5587 3757 2457 3021 2305 6151 584], shape=(10,), dtype=int64)
Epoch 1 Batch 350 Loss 8.9855 Accuracy 0.0007
Predictions: tf.Tensor([5817 4133 5805 2484 7403 5084 3171 1042 4863 1705], shape=(10,), dtype=int64)
Epoch 1 Batch 400 Loss 8.9740 Accuracy 0.0019
Predictions: tf.Tensor([5817 1801 5719 1829 4284 4191 6895 6695 4658 4863], shape=(10,), dtype=int64)
Epoch 1 Batch 450 Loss 8.9607 Accuracy 0.0047
Predictions: tf.Tensor([4133 4284 3822 6895 4425 4663 3 2457 3 3604], shape=(10,), dtype=int64)
Epoch 1 Batch 500 Loss 8.9473 Accuracy 0.0090
Predictions: tf.Tensor([5216 7 3 521 3 3 3 3 3 3], shape=(10,), dtype=int64)
Epoch 1 Batch 550 Loss 8.9329 Accuracy 0.0140
Predictions: tf.Tensor([3 7 5 3 3 3 3 3 3 3], shape=(10,), dtype=int64)
Epoch 1 Batch 600 Loss 8.9183 Accuracy 0.0186
Predictions: tf.Tensor([3 3 7 5 3 3 3 3 3 3], shape=(10,), dtype=int64)
Epoch 1 Batch 650 Loss 8.9023 Accuracy 0.0227
Predictions: tf.Tensor([3 3 3 3 3 3 3 3 3 3], shape=(10,), dtype=int64)
Epoch 1 Batch 700 Loss 8.8860 Accuracy 0.0262
Predictions: tf.Tensor([3 3 3 3 3 3 3 3 3 3], shape=(10,), dtype=int64)
For reference, 3 means <end> in my tokenizer:
>>> roast_tokenizer.index_word[3]
<end>
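One thing the log above is consistent with: a model that outputs only <end> does not score 0% under a padding-masked accuracy, because every target caption contains one <end>. The sketch below is plain Python rather than TensorFlow, and it assumes the targets are the captions shifted by one position (as in the tutorial), so each row holds exactly one <end> before the padding. The token ids in `real` are hypothetical.

```python
PAD_ID = 0  # padding id, masked out of the accuracy
END_ID = 3  # <end> in the tokenizer from the post

def masked_accuracy(real_ids, pred_ids):
    """Fraction of non-padding target positions that are predicted correctly."""
    kept = [(r, p) for r, p in zip(real_ids, pred_ids) if r != PAD_ID]
    return sum(1 for r, p in kept if r == p) / len(kept)

# Hypothetical target row: four word tokens, <end>, then padding.
real = [79, 80, 50, 12, END_ID, PAD_ID, PAD_ID, PAD_ID, PAD_ID, PAD_ID]
all_end = [END_ID] * len(real)

# Only the single <end> position matches, out of five non-padding positions.
print(masked_accuracy(real, all_end))  # 0.2
```

Under that assumption, an all-<end> prediction scores roughly one over the caption length, so a few percent of accuracy at this stage does not by itself mean the model has learned any words.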
What I've tried:
1. Adjusting the learning rate
I've used the normal Adam optimizer with learning rates from 1e-3 to 1e-7, I've tried it with various beta decay values, and I implemented the learning-rate schedule suggested by the TensorFlow tutorial.
I also experimented with varying the maximum of that schedule from 1e-3 to 1e-7.
2. Smaller models:
My model has 2 encoders and 2 decoders with a latent dimensionality of 256. The feed forward networks in the encoders and decoders have 1024 nodes.
I've tried a model with only one encoder and only one decoder and I have the same effect.
3. Different losses:
I use sparse categorical crossentropy on the logits, which is what the TensorFlow tutorial suggests, but I've also tried sequence_loss from the tensorflow_addons package and still get the same issue.
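For readers who don't have the tutorial at hand, the schedule mentioned in item 1 is presumably the warmup schedule from the original Transformer paper, which the tutorial wraps in a Keras schedule class. The plain-Python sketch below only reproduces the formula. `d_model=256` matches the latent dimensionality given in item 2, while `warmup_steps=4000` is an assumed value, not something stated in the post.

```python
def transformer_lr(step, d_model=256, warmup_steps=4000):
    """Linear warmup followed by inverse-square-root decay.

    lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    """
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The rate rises until warmup_steps, peaks there, then decays.
for s in (100, 1000, 4000, 40000):
    print(s, transformer_lr(s))
```

With these assumed values the peak is a little under 1e-3, reached at step 4000, which is in line with the "max of 1e-3" mentioned above.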
I'm running out of things I can test at this point. Lookahead masking and padding masking seem to work fine. Positional encoding seems to work fine. Self attention and cross attention seem to be working fine. It's possible I still have a mistake somewhere in my model, but I have dug pretty deep into the theory behind every aspect of Transformer models and everything still looks correct after weeks of checking, so I'm really stuck here.
I've looked at other Stack Overflow posts about models predicting only the end token, but the only suggestions I've seen are to lower the learning rate or to let the model keep training. I have done both: the model reaches >90% accuracy, plummets to <20% in the second epoch, and then starts climbing slowly back up, but it still seems to be predicting only end tokens, so I have no idea what's happening there.
For additional reference, here are my loss and accuracy methods and some testing to show they work correctly:
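import tensorflow as tf

# Assumed definition (not shown in the original post), following the TensorFlow
# tutorial: per-token loss on raw logits, so padding can be masked out below.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Ignore padding positions (token id 0) when averaging the loss.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

def accuracy_function(real, pred):
    # Compare the argmax over the vocabulary dimension with the labels.
    accuracies = tf.equal(real, tf.cast(tf.argmax(pred, axis=-1), dtype=real.dtype))
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    accuracies = tf.math.logical_and(mask, accuracies)
    accuracies = tf.cast(accuracies, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(accuracies) / tf.reduce_sum(mask)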
real_example = tf.convert_to_tensor([[79,80,50]])
logits = tf.one_hot([79,80,50], 8000) * 1
pred_example = tf.expand_dims(logits, 0)
accuracy_function(real_example, pred_example), loss_function(real_example, pred_example)
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=7.987412>)
This shows that you will get 100% accuracy if you predict the correct tokens. I have that * 1 multiplier because I was playing around with increasing it to 100 or 1000 and seeing the loss continue to decrease (which makes sense because it's logits and not a probability distribution).
logits = tf.one_hot([3,3,3], 8000)
print(logits[0][:5])
pred_example = tf.expand_dims(logits, 0)
accuracy_function(real_example, pred_example), loss_function(real_example, pred_example)
tf.Tensor([0. 0. 0. 1. 0.], shape=(5,), dtype=float32)
(<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=8.9874115>)
In this example, the real labels are kept the same, but the prediction is just end tokens. The accuracy ends up being 0% which is correct, and the loss is high.
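The two loss values in the checks above can be reproduced by hand, which helps confirm the loss is treating its input as logits. This is a plain-Python sketch of softmax cross-entropy for a logit vector that is `scale` at one index and 0 elsewhere; the function name and its parameters are illustrative, not part of the original code.

```python
import math

def one_hot_logit_loss(vocab_size, scale, correct=True):
    """Cross-entropy for logits equal to `scale` at one index and 0 elsewhere.

    `correct` says whether the hot index is the true label. Assumes scale >= 0.
    """
    # log-sum-exp with the largest logit (scale) factored out for stability
    log_z = scale + math.log(1.0 + (vocab_size - 1) * math.exp(-scale))
    target_logit = scale if correct else 0.0
    return log_z - target_logit

print(one_hot_logit_loss(8000, 1.0, correct=True))   # about 7.9874
print(one_hot_logit_loss(8000, 1.0, correct=False))  # about 8.9874
print(one_hot_logit_loss(8000, 1000.0))              # 0.0
print(one_hot_logit_loss(8000, 0.0))                 # ln(8000), about 8.9872
```

The last line is the uniform-prediction baseline, ln(8000), which is close to the loss of about 9.01 at the start of the first log, so a loss still near 9 means the outputs are close to uniform over the vocabulary.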
Any suggestions would be greatly appreciated, thank you!
I found the answer, and it's actually pretty simple: the model is underfitting the data. I originally decided to use only one encoder and one decoder in my Transformer because I want to run it on mobile, but I switched to the full size, which is 12 encoders and 8 decoders (and 12 attention heads), which created a much bigger model. Now it still passes through the phase of predicting only the end token, but then it starts predicting the out-of-vocabulary token too, and then it suddenly starts to experiment with more and more of the less frequent words.
Here is the beginning of epoch 1:
Epoch 1 Batch 0 Loss 9.6558 Accuracy 0.0000
Predictions: reunions union dissapointment onion shitting clap gus earned peabody eraser smiled senior trips cosmetics fuckin’ valderrama
Epoch 1 Batch 50 Loss 9.6610 Accuracy 0.0000
Predictions: googly tumbleweed convince himself cobra absorbing riot zack fuckin’ diarrhea posted videos cleanest od <end> <end>
Epoch 1 Batch 100 Loss 9.6378 Accuracy 0.0018
Predictions: tryin drilling butterfly understand <end> mother's horton overcompensating hogwarts session fuckin’ lookalike videos <end> <end> <end>
Epoch 1 Batch 150 Loss 9.6037 Accuracy 0.0136
Predictions: pumping <end> <end> peripheral <end> <end> <end> shitting <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 200 Loss 9.5651 Accuracy 0.0271
Predictions: <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 250 Loss 9.5163 Accuracy 0.0373
Predictions: <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 300 Loss 9.4657 Accuracy 0.0438
Predictions: <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 350 Loss 9.4148 Accuracy 0.0486
Predictions: <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 400 Loss 9.3692 Accuracy 0.0518
Predictions: <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
...
Epoch 1 Batch 5650 Loss 5.9101 Accuracy 0.3303
Predictions: look <unk> <unk> ? and your face a ? <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 5700 Loss 5.8876 Accuracy 0.3330
Predictions: you look more you're than the <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>
Epoch 1 Batch 5750 Loss 5.8664 Accuracy 0.3354
Predictions: the only <unk> is going to go ? your <end> not <end> your <end> <end> <end>
You can see it's much more sentence-like now. The model still nears 90% accuracy toward the end of the epoch, which I can't explain, so if anyone has an answer for that, let me know. It just seems unbelievably high, even though the model is starting to do pretty well.