 

Restrict Vocab for BERT Encoder-Decoder Text Generation

Is there any way to restrict the vocabulary of the decoder in a Hugging Face BERT encoder-decoder model? I'd like to force the decoder to choose from a small vocabulary when generating text rather than from BERT's entire ~30k-token vocabulary.

Asked Nov 07 '25 by Joseph Harvey

1 Answer

The generate method accepts a bad_words_ids argument where you can provide a list of token-ID sequences that you don't want to appear in the output. To restrict generation to a small vocabulary, you can pass the complement of your allowed set: every token ID that is not one of the tokens you want.
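A minimal sketch of that idea (not the asker's actual setup): the model and tokenizer names, the allowed word list, and the prompt are all placeholders.

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# Keep only these tokens (plus BERT's special tokens) available to the decoder.
allowed_ids = set(tokenizer.convert_tokens_to_ids(["yes", "no", "maybe"]))
allowed_ids.update(tokenizer.all_special_ids)

# bad_words_ids expects a list of token-ID sequences; single-element
# sequences ban individual tokens everywhere in the output.
bad_words_ids = [[i] for i in range(tokenizer.vocab_size) if i not in allowed_ids]

inputs = tokenizer("Is the sky blue?", return_tensors="pt")
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_start_token_id=tokenizer.cls_token_id,
    bad_words_ids=bad_words_ids,
    max_length=8,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that banning ~30k tokens this way can be slow; for a hard whitelist, generate's prefix_allowed_tokens_fn argument, which returns the allowed token IDs at each step, is usually more efficient.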

If you just want to decrease the probability of some tokens being generated, you can try manipulating the bias parameter in the output layer of the model. If your model is based on BERT, the last layer of the decoder is a BertLMPredictionHead. Assuming your seq2seq model is in the variable model, you can access the bias via model.decoder.cls.predictions.decoder.bias and decrease it for the token IDs that should appear with lower probability.
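A hedged sketch of that bias tweak, reusing model and tokenizer from the snippet above; the token list is a placeholder, and the attribute path follows the description in this answer.

```python
import torch

unwanted_ids = tokenizer.convert_tokens_to_ids(["awful", "terrible"])

with torch.no_grad():
    # Subtracting from the bias shifts these tokens' logits down at every
    # decoding step; a very large value effectively bans them outright.
    model.decoder.cls.predictions.decoder.bias[unwanted_ids] -= 10.0
```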

Note also that if you initialize the seq2seq model from BERT (as in the example in the Hugging Face Transformers documentation), you need to fine-tune the model heavily, because the cross-attention weights are initialized randomly and the decoder needs to adapt to left-to-right generation. A minimal fine-tuning forward pass is sketched below.
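A rough sketch of that fine-tuning step, again reusing model and tokenizer from above (dataset, optimizer, and training loop are omitted): passing labels makes the EncoderDecoderModel compute the cross-entropy loss, shifting the labels right internally to form the decoder inputs.

```python
# Required so the model knows how to build decoder inputs from labels.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

batch = tokenizer(["an example source sentence"], return_tensors="pt")
labels = tokenizer(["an example target sentence"], return_tensors="pt").input_ids

outputs = model(
    input_ids=batch.input_ids,
    attention_mask=batch.attention_mask,
    labels=labels,
)
outputs.loss.backward()  # then step an optimizer as usual
```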

Answered Nov 10 '25 by Jindřich


