I've just trained an LSTM language model using PyTorch. The main body of the class is this:
import torch
import torch.nn as nn

class LM(nn.Module):
    def __init__(self, n_vocab,
                 seq_size,
                 embedding_size,
                 lstm_size,
                 pretrained_embed):
        super(LM, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding.from_pretrained(pretrained_embed, freeze=True)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.fc = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state
Now I want to write a function which calculates how good a sentence is, based on the trained language model (some score like perplexity, etc.).
I'm a bit confused and I don't know how I should calculate this.
A similar sample would be of great use.
Perplexity is usually calculated with log base e, as in perplexity = e**(sum(losses) / num_tokenized_tokens), following the convention of recent deep learning frameworks.
To understand the relationship between cross-entropy and perplexity: in general, for a model M, Perplexity(M) = 2^entropy(M), where the entropy is the average negative log2-probability per token. For example, if the model assigns probabilities 1, 0.5 and 1 to three tokens, then
l = (np.log2(1) + np.log2(0.5) + np.log2(1)) / 3 = -0.3333
and the perplexity is np.power(2, -l) = 1.26.
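As a quick sanity check of that arithmetic, here is a small standalone NumPy sketch (the probabilities 1, 0.5 and 1 are just the ones from the example above, not from any real model); it also shows that the e-based formula above gives the same perplexity:
import numpy as np

# per-token probabilities assigned by some model to a 3-token sentence
probs = np.array([1.0, 0.5, 1.0])

# average log-likelihood in bits (base 2)
l = np.mean(np.log2(probs))                      # -0.3333
pp_base2 = np.power(2.0, -l)                     # 2**(1/3) ~= 1.26

# same perplexity via natural log, matching perplexity = e**(sum(losses) / num_tokens)
losses = -np.log(probs)                          # per-token negative log-likelihoods (nats)
pp_base_e = np.exp(losses.sum() / len(losses))   # ~= 1.26 as well

print(pp_base2, pp_base_e)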
When using cross-entropy loss, you can simply apply the exponential function torch.exp() to your loss to calculate perplexity (PyTorch's cross-entropy is computed with the natural log, so exp is the matching base).
So here is just a dummy example:
import torch
import torch.nn.functional as F
num_classes = 10
batch_size = 1
# your model outputs / logits
output = torch.rand(batch_size, num_classes)
# your targets
target = torch.randint(num_classes, (batch_size,))
# getting loss using cross entropy
loss = F.cross_entropy(output, target)
# calculating perplexity
perplexity = torch.exp(loss)
print('Loss:', loss, 'PP:', perplexity)
In my case the output is:
Loss: tensor(2.7935) PP: tensor(16.3376)
Just be aware that if you want the per-word perplexity, you need the per-word losses as well (see the sketch below).
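Putting this together with the LM class from the question, a sentence-scoring helper could look roughly like the sketch below. It is only a sketch under a few assumptions: word_to_ix is whatever word-to-index mapping was used during training (hypothetical here), the initial hidden and cell states are zeros, and reduction='none' in F.cross_entropy gives the per-token losses mentioned above:
import torch
import torch.nn.functional as F

def sentence_perplexity(model, sentence, word_to_ix, device='cpu'):
    # sentence: a list of tokens; word_to_ix: hypothetical vocab mapping used at training time
    ids = [word_to_ix[w] for w in sentence]
    x = torch.tensor(ids, device=device).unsqueeze(0)    # shape (1, seq_len)

    model.eval()
    with torch.no_grad():
        # zero initial state for the single-layer LSTM: (num_layers, batch, lstm_size)
        h0 = torch.zeros(1, 1, model.lstm_size, device=device)
        c0 = torch.zeros(1, 1, model.lstm_size, device=device)

        # predict token t+1 from tokens up to t
        logits, _ = model(x[:, :-1], (h0, c0))            # (1, seq_len-1, n_vocab)
        targets = x[:, 1:]                                # (1, seq_len-1)

        # per-token cross-entropy (natural log), one loss value per predicted token
        losses = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                 targets.reshape(-1),
                                 reduction='none')

        # sentence perplexity = exp(average per-token loss); lower is better
        return torch.exp(losses.mean()).item()
Lower perplexity means the model considers the sentence more likely; if you prefer a log-likelihood score instead, return -losses.sum().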
Here is a neat language-model example that also computes the perplexity from the output, which might be interesting to look at:
https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/main.py#L30-L50