Torch JIT Trace = TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect

I am following this tutorial: https://huggingface.co/transformers/torchscript.html to create a trace of my custom BERT model. However, when running the exact same dummy_input I receive a warning:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
We can't record the data flow of Python values, so this value will be treated as a constant in the future.

Having loaded in my model and tokenizer, the code to create the trace is the following:

text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

traced_model = torch.jit.trace(model, dummy_input)

The dummy_input is a list of tensors, so I'm not sure where a Boolean type comes into play here. Does anyone understand why this warning occurs and where the Boolean conversion is happening?

Many Thanks

asked by gerardcslabs

1 Answer

What this warning means

This warning occurs when one tries to torch.jit.trace models that have data-dependent control flow.

This simple example should make it clearer:

import torch


class Foo(torch.nn.Module):
    def forward(self, tensor):
        # Data-dependent condition:
        # the trace will only record the branch taken for the example input
        if tensor.max() > 0.5:
            return tensor ** 2
        return tensor


model = Foo()
scripted = torch.jit.script(model)  # No warnings
traced = torch.jit.trace(model, torch.randn(10))  # TracerWarning

In essence, the BERT model has some control flow (if statements, for loops) that depends on the data, hence you get the warning.
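
To make the consequence concrete, here is a small, hedged continuation of the snippet above (the example inputs are only illustrative): the trace permanently bakes in whichever branch was taken during tracing.

# Trace with an input whose max is below 0.5,
# so only the "return tensor" branch gets recorded
traced = torch.jit.trace(model, torch.zeros(10))

x = torch.full((10,), 2.0)  # in eager mode this input takes the "tensor ** 2" branch

print(model(x)[0].item())   # 4.0 -- eager mode re-evaluates the condition
print(traced(x)[0].item())  # 2.0 -- the trace replays only the recorded branch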

Warning itself

You can see BERT forward code here.

You are fine if:

  • the arguments do not change (e.g. None values passed to forward) and will stay that way after tracing (e.g. during inference calls)
  • the control flow is based on data gathered inside __init__ (like configs), because this will not change

For example:

elif input_ids is not None:
    input_shape = input_ids.size()
    batch_size, seq_length = input_shape

This will only run as one branch with torch.jit.trace, as tracing just records operations on tensors and is unaware of control flow like this.

The HuggingFace team is probably aware of this, and the warning is not an issue (though you might double-check against your use case or try to go with torch.jit.script).
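
If you want to follow the linked tutorial to the end despite the warning, a minimal sketch of the usual flow looks roughly like this (assuming the bert-base-uncased checkpoint used in the tutorial; the torchscript=True flag configures the model so it can be traced, and the dummy input mirrors the one from the question):

import torch
from transformers import BertModel, BertTokenizer

# torchscript=True makes the model return tuple outputs and handle tied weights so tracing works
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
segments_ids = [0] * 7 + [1] * 7

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# The TracerWarning may still appear here; per the above it is expected
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")

# Reload and use like the original model (for inputs shaped like the dummy input)
loaded = torch.jit.load("traced_bert.pt")
outputs = loaded(tokens_tensor, segments_tensors)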

Going with torch.jit.script

This one would be hard, as the whole model has to be TorchScript-compatible (TorchScript supports only a subset of Python and more than likely will not work out of the box with BERT).

Do it only when necessary (which is probably not the case here).
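
For completeness, a short sketch reusing the toy Foo/model from the first snippet: scripting compiles the actual if/else, so both branches survive and behave as in eager mode (which is why it emits no TracerWarning).

scripted = torch.jit.script(model)  # script the same toy Foo module from the first snippet

small = torch.zeros(10)
large = torch.full((10,), 2.0)

# Both branches behave as in eager mode
assert torch.equal(scripted(small), model(small))   # "return tensor" branch
assert torch.equal(scripted(large), model(large))   # "tensor ** 2" branch
assert torch.equal(scripted(large), large ** 2)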

answered by Szymon Maszke