
How to predict the probability of an empty string using BERT

Suppose we have a template sentence like this:

  • "The ____ house is our meeting place."

and we have a list of adjectives to fill in the blank, e.g.:

  • "yellow"
  • "large"
  • ""

Note that one of these is an empty string.

The goal is to compare the probabilities to select the most likely word to describe "house" given the context of the sentence. If it's more likely to have nothing, this should also be taken into consideration.

We can predict the probability of each word filling in the blank, but how would we predict the likelihood of there being no adjective to describe "house" at all?

To predict the probability of a word:

from transformers import BertTokenizer, BertForMaskedLM
import torch
from torch.nn import functional as F

# Load BERT tokenizer and pre-trained model
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForMaskedLM.from_pretrained('bert-large-uncased', return_dict=True)

targets = ["yellow", "large"]
sentence = "The [MASK] house is our meeting place."

# Using BERT, compute probability over its entire vocabulary, returning logits
input = tokenizer.encode_plus(sentence, return_tensors = "pt") 
mask_index = torch.where(input["input_ids"][0] == tokenizer.mask_token_id)[0] 
with torch.no_grad():
    output = model(**input) 

# Run softmax over the logits to get the probabilities
softmax = F.softmax(output.logits[0], dim=-1)

# Find the words' probabilities in this probability distribution
target_probabilities = {t: softmax[mask_index, tokenizer.vocab[t]].numpy()[0] for t in targets}
target_probabilities

This outputs a dictionary of the words and their associated probabilities:

{'yellow': 0.0061520976, 'large': 0.00071377633}
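
As a quick sanity check, we can also inspect the top of the same distribution to see which fillers BERT itself prefers at the masked position, for example the five most probable tokens:

# Inspect the most probable fillers for the masked position,
# using the same softmax distribution computed above
top_probs, top_ids = torch.topk(softmax[mask_index], k=5)
print([(tokenizer.convert_ids_to_tokens(int(i)), float(p))
       for i, p in zip(top_ids[0], top_probs[0])])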

If I try to add an empty string to the list, I get the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-62-6f726220a108> in <module>
     18 
     19 # Find the words' probabilities in this probability distribution
---> 20 target_probabilities = {t: softmax[mask_index, tokenizer.vocab[t]].numpy()[0] for t in targets}
     21 target_probabilities

<ipython-input-62-6f726220a108> in <dictcomp>(.0)
     18 
     19 # Find the words' probabilities in this probability distribution
---> 20 target_probabilities = {t: softmax[mask_index, tokenizer.vocab[t]].numpy()[0] for t in targets}
     21 target_probabilities

KeyError: ''

This is because BERT's vocabulary contains no empty string, so we can't look up the probability of something that doesn't exist in the model.
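
A simple guard avoids the KeyError itself (though it doesn't answer the real question) by only looking up targets that exist as single tokens in the vocabulary, for example:

# Only look up targets that are single tokens in BERT's WordPiece vocabulary;
# the empty string (and any multi-token word) is skipped
valid_targets = [t for t in targets if t in tokenizer.vocab]
target_probabilities = {t: softmax[mask_index, tokenizer.vocab[t]].numpy()[0]
                        for t in valid_targets}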

How should we get the probability of there being no word to fill in the blank? Is this possible with the model? Does it make sense to use the empty token [PAD] instead of an empty string? (I've only seen [PAD] used at the end of sentences, to make a group of sentences the same length.)

asked Dec 27 '21 by brienna


1 Answer

One solution to this problem is to score each complete candidate sentence by summing the log-softmax of every token, and then compare those sentence scores.

First, I should say that BERT's logits aren't really probabilities, even after applying softmax to them. But this seems to be an acceptable approximation, so I will use it as well.

Second, you should also consider cases where the adjective is split into several tokens. My solution handles multi-token fillers as well.
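
For instance, a two-word filler cannot be scored against a single [MASK] position, as a quick check of the tokenizer shows:

print(tokenizer.tokenize("very large"))  # splits into two tokens: ['very', 'large']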

Here is the code fix:

targets = ["", "yellow", "large", "very large"]
target_log_P = {t: None for t in targets}
for target in target_log_P:
    # Build the full candidate sentence; an empty string simply removes the adjective
    candidate = sentence.replace("[MASK]", target)
    inputs = tokenizer.encode_plus(candidate, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    # Sum the log-softmax score of every token in the candidate sentence
    target_log_P[target] = sum([
        torch.log(F.softmax(output.logits[0][i], dim=-1)[idx])
        for i, idx in enumerate(inputs["input_ids"][0])
    ]).item()

There may be a pipeline for this and my solution is probably not the standard way, but it seems to work.

Here are the results:

>>> target_log_P
{'': -37.5234375, 'yellow': -37.08171463012695, 'large': -35.85972213745117, 'very large': -46.483154296875}
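
To compare the candidates from here, one option (again, just a sketch, not a standard pipeline) is to normalize these sentence-level log-scores with a softmax, which gives a relative probability over the candidates, where "" stands for "no adjective":

# Turn the sentence log-scores into relative probabilities over the candidates
# and pick the best filler ("" means no adjective at all)
scores = torch.tensor([target_log_P[t] for t in targets])
relative_P = torch.softmax(scores, dim=0)
print({t: float(p) for t, p in zip(targets, relative_P)})
print("best:", max(target_log_P, key=target_log_P.get))

Note that longer candidate sentences accumulate more negative log-terms, so this comparison tends to favor shorter fillers somewhat.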
answered Sep 19 '22 by Mehdi