Where is perplexity calculated in the Huggingface gpt2 language model code?

I see some GitHub comments saying that the loss returned by the model() call is already in the form of perplexity: https://github.com/huggingface/transformers/issues/473

But when I look at the relevant code (https://huggingface.co/transformers/_modules/transformers/modeling_openai.html#OpenAIGPTLMHeadModel.forward), this is what I see:

    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        outputs = (loss,) + outputs

    return outputs  # (loss), lm_logits, (all hidden states), (all attentions)
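To make the shift concrete, here is a minimal pure-Python sketch of what the snippet above computes for a single sequence. It is an illustration, not the library's code. It assumes one unbatched, unpadded sequence and ignores `CrossEntropyLoss`'s `ignore_index`. The helper names `log_softmax` and `shifted_lm_loss` are hypothetical. The value it returns is the mean cross entropy in nats, i.e. the average of -log p(next token), and no exponentiation happens anywhere.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def shifted_lm_loss(lm_logits: List[List[float]], labels: List[int]) -> float:
    """Mean next-token cross entropy (in nats) for one sequence.

    Hypothetical helper mirroring lm_logits[..., :-1, :] and labels[..., 1:]:
    the logits at position t are scored against the token at position t + 1.
    lm_logits has one row of vocabulary logits per position; labels has one
    token id per position.
    """
    shift_logits = lm_logits[:-1]  # last position has no next token to predict
    shift_labels = labels[1:]      # first token is not predicted by anything
    nlls = [-log_softmax(row)[tok] for row, tok in zip(shift_logits, shift_labels)]
    # CrossEntropyLoss() defaults to reduction="mean": average over predicted tokens
    return sum(nlls) / len(nlls)
```

For example, with all-zero logits over a 4-token vocabulary every prediction is uniform, so `shifted_lm_loss` returns ln 4 ≈ 1.386. That is a cross entropy, not a perplexity.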

I see cross entropy being calculated, but no transformation into perplexity. Where does the loss finally get transformed? Or is there a transformation already there that I'm not understanding?

asked Mar 24 '20 13:03 by user947659

1 Answer

Ah, OK, I found the answer. The code actually returns the cross-entropy loss (the mean negative log-likelihood of the next token), not perplexity. The GitHub comment calls it perplexity only because the OP of that issue does

    return math.exp(loss)

which transforms the cross entropy into perplexity :)
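The conversion is just exponentiation of the mean per-token cross entropy. Below is a minimal sketch, assuming the loss is in nats as `CrossEntropyLoss` produces it (for a base-2 loss you would use `2 ** loss` instead). The function names are hypothetical. With a PyTorch tensor, `torch.exp(loss)` does the same thing as `math.exp` on a float.

```python
import math
from typing import Sequence


def perplexity_from_loss(mean_nll: float) -> float:
    """Perplexity = exp(mean cross entropy); assumes the loss is in nats."""
    return math.exp(mean_nll)


def perplexity_from_probs(token_probs: Sequence[float]) -> float:
    """Same quantity, computed from the probability given to each target token."""
    # Mean negative log-likelihood over the predicted tokens
    mean_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(mean_nll)
```

For target-token probabilities 0.5, 0.25 and 0.125 the mean loss is 2 ln 2, so the perplexity is 4, the geometric mean of the inverse probabilities. If you average losses from chunks of different lengths before exponentiating, the result only approximates the token-level perplexity of the whole text.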

answered Nov 29 '22 13:11 by user947659