
TorchScript with model containing multiple heads

Tags: python, torch

My goal is to serialize a trained PyTorch model and load it in an environment where the original class defining the neural network is not available. To achieve that, I decided to use TorchScript, since it seems to be the only way to do so.

I have a multi-task model (type nn.Module) built using a body common to every task (also nn.Module, just a few linear layers) and a set of linear head models, one per task. I store the head models in a dictionary Dict[int, nn.Module] called _task_head_models and I created an ad-hoc forward method in my module class to select the right head at prediction time:

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
            )

        return self._task_head_models[task_id](self._model(x))

This works fine until I try to serialize it using TorchScript. When I call torch.jit.script(mymodule), I get:

Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)

Something that seems off is that my module contains a Dict, not a dict as the error message says. Setting that aside for a second, it's still unclear why this is happening. Dictionaries seem to be supported according to the language reference: https://docs.w3cub.com/pytorch/jit_language_reference.html

I also tried using ModuleDict instead of Dict (changing the key type to str), but that doesn't work either: Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...'
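The enumeration hint in that last error message can be turned into a working forward method. What follows is only a minimal sketch under assumptions: the class name MultiHeadModel, the str argument task_key, and the layer sizes are made up for illustration and are not the original module. TorchScript unrolls the loop over the ModuleDict at compile time, so no head is ever indexed with a runtime key.

```python
from typing import List

import torch
from torch import nn


class MultiHeadModel(nn.Module):
    """Hypothetical stand-in for the multi-task module described above."""

    def __init__(self, task_keys: List[str]) -> None:
        super().__init__()
        # Shared body (sizes are assumptions for this sketch).
        self._model = nn.Linear(3, 8)
        # One head per task; ModuleDict keys must be strings.
        self._task_head_models = nn.ModuleDict(
            {key: nn.Linear(8, 24) for key in task_keys}
        )

    def forward(self, x: torch.Tensor, task_key: str = "0") -> torch.Tensor:
        features = self._model(x)
        out = features
        found = False
        # Enumerate instead of indexing: the loop is unrolled when scripting.
        for key, head in self._task_head_models.items():
            if key == task_key:
                out = head(features)
                found = True
        if not found:
            raise ValueError("Unknown task key: " + task_key)
        return out


eager = MultiHeadModel(["0", "1"]).eval()
scripted = torch.jit.script(eager)
x = torch.randn(2, 3)
out = scripted(x, "1")
```

The heads here all return a tensor, which lets out keep a single type across the unrolled branches.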

giz asked Dec 20 '25 16:12


1 Answer

If there are not many items in the Dict _task_head_models, I think an if-else branch can help you: store each head as its own attribute and select it by comparing task_id. The sample code is as follows:

    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # One plain attribute per head instead of a dictionary of modules.
            self._task_head0 = torch.nn.Linear(3, 24)
            self._task_head1 = torch.nn.Linear(3, 24)

        def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
            # Select the head with explicit branches on task_id.
            if task_id == 0:
                return self._task_head0(x)
            elif task_id == 1:
                return self._task_head1(x)
            else:
                raise ValueError(
                    f"The task id {task_id} is not valid. Valid task ids are 0,1."
                )
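To check that this approach meets the original goal of loading the model without its Python class, here is a rough sketch of the full round trip. The name TwoHeadModel, the shared _model body, and the in-memory buffer are assumptions added for illustration. In practice you would pass a file path to torch.jit.save and torch.jit.load instead.

```python
import io

import torch


class TwoHeadModel(torch.nn.Module):
    """Hypothetical two-head module using the if-else selection."""

    def __init__(self) -> None:
        super().__init__()
        self._model = torch.nn.Linear(3, 3)  # assumed shared body
        self._task_head0 = torch.nn.Linear(3, 24)
        self._task_head1 = torch.nn.Linear(3, 24)

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        h = self._model(x)
        if task_id == 0:
            return self._task_head0(h)
        elif task_id == 1:
            return self._task_head1(h)
        else:
            raise ValueError("Invalid task id; valid task ids are 0 and 1.")


eager = TwoHeadModel().eval()
scripted = torch.jit.script(eager)

# Serialize to an in-memory buffer, then load it back.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(2, 3)
out = loaded(x, 1)
```

The loaded object is a scripted module that carries its own code, so the loading side needs torch but not the TwoHeadModel definition.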
dipper answered Dec 23 '25 07:12


