For my use case, I need to use model.forward() instead of model.generate(), i.e. instead of the code below:
outs = model.model.generate(
    input_ids=batch["source_ids"],
    attention_mask=batch["source_mask"],
    output_scores=True,
    max_length=model.model_arguments.max_output_seq_length,
)
preds_cleaned = [
    model.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for ids in outs
]
I need to use
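model_outputs = model.model(
    input_ids=batch["source_ids"],
    attention_mask=batch["source_mask"],
    labels=lm_labels.to(device),
    decoder_attention_mask=batch["target_mask"],
)
logits = model_outputs.logits
# m is assumed to be a softmax over the vocabulary dimension, e.g. torch.nn.Softmax(dim=2)
softmax_logits = m(logits)
# torch.max with dim returns (values, indices); the indices are the predicted token ids
max_logits = torch.max(softmax_logits, dim=2)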
    
Decoding these logits gives unprocessed text with many issues, such as repeated words at the end. What do I need to do to get the same result as model.generate()?
The two methods do completely different things.
Calling the model (that is, its forward method) with labels uses them for teacher forcing: the inputs to the decoder are the labels shifted by one position (see documentation). With teacher forcing, the decoder always receives the ground-truth token as its next input, no matter what it predicted. Teacher forcing is used for model training, and all steps are fully differentiable.
When you call the generate method, the model is used in an autoregressive fashion: every token it generates is fed back as the input for the next step. However, selecting a token is a "hard" decision, and the gradient cannot be propagated through it, so the generate method cannot be used for training. The output is coherent because the decoder reacts to what was previously generated.
With teacher forcing, the model might prefer to generate a particular token and then continue consistently with it. However, it cannot, because it is forced to continue as if it had generated the token that is actually in the labels argument. This is why you observe incoherent output (which was never intended to be used as output, only for training).
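To make the difference concrete, here is a minimal, self-contained sketch with a toy stand-in for the model. Everything in it is hypothetical: NEXT is a made-up lookup table playing the role of the argmax over the logits, and forward, shift_right, and greedy_generate are illustrative helpers, not Hugging Face APIs.

```python
# Toy stand-in for a trained decoder (hypothetical, for illustration only):
# the most likely next token depends only on the previous input token.
NEXT = {
    "<s>": "a", "a": "dog", "dog": "barks", "barks": "</s>",
    "the": "cats", "cats": "sleep", "sleep": "</s>",
}


def forward(decoder_inputs):
    """One pass over all positions; returns the top prediction per position."""
    return [NEXT[token] for token in decoder_inputs]


def shift_right(labels, start="<s>"):
    """Decoder inputs for teacher forcing: the labels shifted by one."""
    return [start] + labels[:-1]


def greedy_generate(max_new_tokens=10, start="<s>", end="</s>"):
    """Autoregressive decoding: feed each prediction back in as input."""
    tokens = [start]
    for _ in range(max_new_tokens):
        next_token = forward(tokens)[-1]  # only the last position is new
        tokens.append(next_token)
        if next_token == end:
            break
    return tokens[1:]  # drop the start token


labels = ["the", "cats", "sleep", "</s>"]
teacher_forced = forward(shift_right(labels))
generated = greedy_generate()
print(teacher_forced)  # ['a', 'cats', 'sleep', '</s>']
print(generated)       # ['a', 'dog', 'barks', '</s>']
```

The teacher-forced predictions mix the toy model's own first choice ("a") with continuations of the label tokens, which is the kind of incoherence described above. With a real encoder-decoder model, the same loop would call the model with decoder_input_ids (instead of labels) and take the argmax of the logits at the last position on each step. That corresponds to greedy decoding only; generate() may be configured to use beam search or sampling instead.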