How to use the past with HuggingFace Transformers GPT-2?

I have:

        context = torch.tensor(context, dtype=torch.long, device=self.device)
        context = context.unsqueeze(0)
        generated = context
        with torch.no_grad():
            past_outputs = None
            for i in trange(num_words):
                print(i, num_words)
                inputs = {"input_ids": generated}

                outputs, past_outputs = self.model(
                    **inputs,
                    past=past_outputs
                )
                next_token_logits = outputs[
                    0, -1, :] / (temperature if temperature > 0 else 1.0)

                # repetition penalty from CTRL
                # (https://arxiv.org/abs/1909.05858)
                for _ in set(generated.view(-1).tolist()):
                    next_token_logits[_] /= repetition_penalty

                filtered_logits = top_k_top_p_filtering(
                    next_token_logits, top_k=top_k, top_p=top_p)
                if temperature == 0:  # greedy sampling:
                    next_token = torch.argmax(filtered_logits).unsqueeze(0)
                else:
                    next_token = torch.multinomial(
                        F.softmax(filtered_logits, dim=-1), num_samples=1)

                generated = torch.cat(
                    (generated, next_token.unsqueeze(0)), dim=1)

This works for the first iteration, but then I get an error for the next iteration:

  File "/Users/shamoon/Sites/wordblot/packages/ml-server/generator.py", line 143, in sample_sequence
    past=past_outputs
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 601, in forward
    output_hidden_states=output_hidden_states,
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 470, in forward
    position_embeds = self.wpe(position_ids)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/functional.py", line 1724, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

Is there something I'm doing wrong?

asked Aug 03 '20 15:08 by Shamoon


1 Answer

I believe the problem is that context contains integer values exceeding the vocabulary size. My assumption is based on the last traceback line:

    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    IndexError: index out of range in self
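
One way to test this assumption is to check every index against the size of the embedding table it is looked up in. The sketch below is plain Python and only illustrative: the helper names `out_of_range` and `positions_with_past` are hypothetical, and the sizes 50257 (vocabulary) and 1024 (positions) are assumed from the standard GPT-2 configuration, so compare them with `model.config` for your checkpoint. Since the traceback also passes through `self.wpe(position_ids)`, it may be worth applying the same check to position indices. As far as I know, when `past` is supplied GPT-2 numbers the new inputs starting at the cached length, which is why the second helper offsets positions that way.

```python
# Assumed sizes for a standard GPT-2 checkpoint; verify against model.config.
GPT2_VOCAB_SIZE = 50257
GPT2_N_POSITIONS = 1024


def out_of_range(indices, table_size):
    """Return the indices that would fail an embedding lookup of this size."""
    return [i for i in indices if not 0 <= i < table_size]


def positions_with_past(past_length, input_length):
    """Position ids for new inputs when `past_length` tokens are already cached.

    Hypothetical helper: mirrors numbering positions from the cached length.
    """
    return list(range(past_length, past_length + input_length))


if __name__ == "__main__":
    # 1) Token ids: anything outside [0, vocab_size) breaks the token embedding.
    context = [15496, 11, 995, 50257]  # the last id is deliberately too large
    print("bad token ids:", out_of_range(context, GPT2_VOCAB_SIZE))

    # 2) Positions: a 600-token prompt, second loop iteration.
    #    Feeding the whole sequence again (601 tokens) on top of 600 cached ones:
    full = positions_with_past(past_length=600, input_length=601)
    print("bad positions (full sequence):", len(out_of_range(full, GPT2_N_POSITIONS)))

    #    Feeding only the newly sampled token:
    last = positions_with_past(past_length=600, input_length=1)
    print("bad positions (last token only):", len(out_of_range(last, GPT2_N_POSITIONS)))
```

If the token check comes back empty but the position check does not, the likely cause is that the loop passes the whole `generated` tensor together with `past`. The library documentation says that when `past` is used, only the tokens whose past has not yet been computed should be passed as `input_ids`, for example the most recent token rather than the full sequence.
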
answered Oct 24 '22 10:10 by roman