
How to use forward() method instead of model.generate() for T5 model

For my use case, I need to use model.forward() instead of the model.generate() method, i.e. instead of the code below

outs = model.model.generate(input_ids=batch['source_ids'],
                            attention_mask=batch['source_mask'],
                            output_scores=True,
                            max_length=model.model_arguments.max_output_seq_length)

preds_cleaned = [model.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                 for ids in outs]

I need to use

m = torch.nn.Softmax(dim=-1)  # softmax over the vocabulary dimension

model_outputs = model.model(
    input_ids=batch["source_ids"],
    attention_mask=batch["source_mask"],
    labels=lm_labels.to(device),  # lm_labels: the target token ids
    decoder_attention_mask=batch['target_mask']
)
logits = model_outputs.logits
softmax_logits = m(logits)
# torch.max over dim=2 returns a (values, indices) tuple;
# the indices are the per-step argmax token ids
max_logits = torch.max(softmax_logits, dim=2)

Decoding these logits gives unprocessed text with many issues, such as repeated words at the end. What do I need to do to get the same result as model.generate()?

asked Sep 01 '25 by NRJ_Varshney


1 Answer

The two methods do something completely different.

Calling the model (which means calling the forward method) uses the labels for teacher forcing. This means the inputs to the decoder are the labels shifted by one (see documentation). With teacher forcing, the decoder always gets the ground-truth token as input in the next step, no matter what its prediction was. Teacher forcing is used for model training, where all steps must be fully differentiable.
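The shift can be sketched in plain Python. The token ids below are made up; the one real detail is that T5 uses its pad token (id 0) as the decoder start token:

```python
def shift_right(labels, decoder_start_token_id=0):
    """Build decoder inputs from labels the way T5's forward() does
    internally: prepend the decoder start token and drop the last label."""
    return [decoder_start_token_id] + labels[:-1]

# hypothetical target ids ending with </s> (id 1)
labels = [37, 19, 3, 9, 1525, 1]
shift_right(labels)  # [0, 37, 19, 3, 9, 1525]
```

So at step t the decoder is conditioned on labels[0..t-1], never on its own predictions.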

When you call the generate method, the model is used in an autoregressive fashion: every token it generates is fed back as the input in the next step. However, selecting the token is a "hard" decision, and the gradient cannot be propagated through it, so the generate method cannot be used for training. The output is coherent because the decoder reacts to what it previously generated.
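Stripped of beam search and all the bells and whistles, the loop generate runs looks roughly like this (step_fn is my stand-in for one forward pass returning the argmax token for the current prefix; it is not a transformers API):

```python
def greedy_generate(step_fn, start_id, eos_id, max_length):
    """Minimal sketch of autoregressive decoding: each predicted token is
    fed back as the next decoder input, so the model sees its own output."""
    ids = [start_id]
    for _ in range(max_length):
        next_id = step_fn(ids)  # "hard" argmax decision; not differentiable
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids
```

The append-and-feed-back step is exactly what teacher forcing replaces with the ground-truth label.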

With teacher forcing, the model might prefer to generate some token and then continue consistently with it. But it cannot: it is forced to continue as if it had generated the token that is actually in the labels argument. This is why you observe incoherent output when you decode the teacher-forced logits; that output was never meant to be decoded as text, only to be scored against the labels during training.
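If you really need generate-like output from forward calls only, you have to reproduce the autoregressive loop yourself: call the model with decoder_input_ids holding everything generated so far, and no labels. A minimal greedy sketch, assuming a T5-style encoder-decoder call signature (no beam search, caching, or length penalties, and the function name is mine):

```python
import torch

def greedy_decode_with_forward(model, input_ids, attention_mask,
                               start_id, eos_id, max_length):
    """Greedy decoding using only forward passes: decoder inputs are the
    tokens generated so far, so the result matches greedy generate()."""
    decoder_input_ids = torch.full((input_ids.size(0), 1), start_id,
                                   dtype=torch.long)
    for _ in range(max_length):
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids)
        # take the argmax over the vocabulary at the last decoder position
        next_ids = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_ids], dim=-1)
        if (next_ids == eos_id).all():
            break
    return decoder_input_ids
```

This recomputes the encoder on every step, which the real generate avoids with caching; it is a sketch of the mechanism, not a drop-in replacement.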

answered Sep 03 '25 by Jindřich