I am following this tutorial: https://huggingface.co/transformers/torchscript.html
to create a trace of my custom BERT model. However, when running it with the exact same dummy_input
I receive an error:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
We can't record the data flow of Python values, so this value will be treated as a constant in the future.
Having loaded in my model and tokenizer, the code to create the trace is the following:
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
traced_model = torch.jit.trace(model, dummy_input)
The dummy_input is a list of tensors, so I'm not sure where the Boolean type is coming into play here. Does anyone understand why this error is occurring and whether the Boolean conversion is happening?
Many Thanks
This warning occurs when one tries to torch.jit.trace models which have data-dependent control flow.
This simple example should make it clearer:
import torch

class Foo(torch.nn.Module):
    def forward(self, tensor):
        # It is data dependent
        # Trace will only work with one path
        if tensor.max() > 0.5:
            return tensor ** 2
        return tensor

model = Foo()
traced = torch.jit.script(model)  # No warnings
traced = torch.jit.trace(model, torch.randn(10))  # Warning
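To see what "only one path" means in practice, here is a minimal sketch that reuses the Foo module above. The input values (2.0 and 0.1) are chosen for illustration only: the first makes the trace record the squaring branch, the second would take the other branch in eager mode.

```python
import torch

class Foo(torch.nn.Module):
    def forward(self, tensor):
        # Data-dependent branch: the result depends on tensor values
        if tensor.max() > 0.5:
            return tensor ** 2
        return tensor

model = Foo()

# max() is 2.0 > 0.5, so the trace records only the squaring branch
example = torch.full((4,), 2.0)
traced = torch.jit.trace(model, example)  # emits the TracerWarning
scripted = torch.jit.script(model)        # keeps the `if` in the graph

# max() is 0.1, so eager mode returns the input unchanged
small = torch.full((4,), 0.1)
eager_out = model(small)        # input unchanged
traced_out = traced(small)      # still squared: the branch was frozen
scripted_out = scripted(small)  # input unchanged, same as eager

print(eager_out, traced_out, scripted_out)
```

The traced module silently disagrees with the eager one on the second input, which is exactly the situation the warning is about.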
In essence, the BERT model has some control flow (like if statements and for loops) that depends on the data, hence you get the warning.
You can see this in the BERT forward code.
You are fine if:
- the arguments you pass do not change (e.g. None values passed to forward) and it will stay that way after tracing (e.g. during inference calls)
- the control flow depends only on values set in __init__ (like configs), because these will not change

For example:
elif input_ids is not None:
    input_shape = input_ids.size()
    batch_size, seq_length = input_shape
This will only run as one branch with torch.jit.trace, as it just traces operations on tensors and is unaware of control flow like this.
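As a rough illustration of the harmless case, the sketch below traces a small module whose branches depend only on a None argument and on a flag set in __init__. The Block class and its scale_output flag are hypothetical stand-ins, not part of BERT. Because neither condition inspects tensor values, tracing should not emit a TracerWarning, and the trace stays valid as long as the module is called the same way.

```python
import warnings
import torch

class Block(torch.nn.Module):
    """Hypothetical module whose branches depend only on Python values."""

    def __init__(self, scale_output: bool):
        super().__init__()
        self.scale_output = scale_output  # fixed in __init__, like a config value

    def forward(self, input_ids, inputs_embeds=None):
        # `is not None` checks the argument itself, not the tensor data
        if input_ids is not None:
            x = input_ids.float()
        else:
            x = inputs_embeds
        # A plain Python bool, so this branch never changes after construction
        if self.scale_output:
            x = x * 2.0
        return x

model = Block(scale_output=True)

# Record every warning raised while tracing
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, (torch.tensor([[1, 2, 3]]),))

tracer_warnings = [
    w for w in caught if issubclass(w.category, torch.jit.TracerWarning)
]

# Same calling convention as during tracing, different sequence length
new_input = torch.tensor([[4, 5]])
result = traced(new_input)
print(len(tracer_warnings), result)
```

If inference later passed inputs_embeds instead of input_ids, the recorded branch would no longer match, so the trace would need to be redone with that calling convention.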
The HuggingFace team is probably aware of that, and this warning is not an issue (though you might double-check with your use case or try to go with torch.jit.script).
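One way to double-check a traced model against your use case is to compare its outputs with the eager model on inputs that differ from the tracing example. The sketch below is only an assumption about how such a check could look: TinyEncoder is a hypothetical stand-in so the example is self-contained, and outputs_match is an illustrative helper, not a PyTorch or transformers function. With a real BERT model you would pass token and segment tensors of the lengths you expect at inference.

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for the BERT model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(100, 8)
        self.proj = torch.nn.Linear(8, 4)

    def forward(self, input_ids, token_type_ids):
        hidden = self.embed(input_ids) + token_type_ids.unsqueeze(-1).float()
        return self.proj(hidden)

def outputs_match(eager, traced, inputs, atol=1e-5):
    """Return True if the traced and eager models agree on every input tuple."""
    with torch.no_grad():
        return all(
            torch.allclose(eager(*args), traced(*args), atol=atol)
            for args in inputs
        )

torch.manual_seed(0)
model = TinyEncoder().eval()

# Trace with one example shape (batch of 1, 14 tokens)
dummy_input = (torch.randint(0, 100, (1, 14)), torch.zeros(1, 14, dtype=torch.long))
traced_model = torch.jit.trace(model, dummy_input)

# Check on inputs with other batch sizes and sequence lengths
test_inputs = [
    (torch.randint(0, 100, (1, 5)), torch.ones(1, 5, dtype=torch.long)),
    (torch.randint(0, 100, (2, 9)), torch.zeros(2, 9, dtype=torch.long)),
]
ok = outputs_match(model, traced_model, test_inputs)
print(ok)
```

A passing check does not prove the trace is correct for every input, but a mismatch on realistic inputs is a clear sign that a data-dependent branch was frozen.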
torch.jit.script
This one would be hard, as the whole model has to be TorchScript compatible (TorchScript supports only a subset of Python and more than likely will not work out of the box with BERT).
Do it only when necessary (probably not).