
How to do multi-head learning

I have about 5 models that work pretty well when trained individually, but I want to fuse them into one big model. I'm looking into this because one big model is easier to update in production than many small models. This image shows what I want to achieve: [image]

My questions are: is it OK to do it like this? And with one dataset per head, how am I supposed to train the whole model?

PrinceZee asked Nov 20 '19 08:11




1 Answer

My question is: is it OK to do it like this?

Sure, you can do that. This approach is called multi-task learning. Depending on your datasets and what you are trying to do, it may even improve performance. Microsoft used a multi-task model to achieve good results on the GLUE NLP benchmark, but they also noted that you can improve performance further by fine-tuning the joint model for each individual task.
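A minimal sketch of that per-task fine-tuning step, assuming the joint model can be split into a shared trunk and a task head. The names `shared` and `depth_head`, the layer sizes, and the random tensors are all hypothetical stand-ins, not taken from the Microsoft work or from your models:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical jointly trained pieces: a shared trunk and one task head.
shared = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
depth_head = nn.Linear(16, 1)

# Copy both so the joint model stays untouched, then fine-tune the copy
# on a single task's data (random toy tensors stand in for that data).
task_model = nn.Sequential(copy.deepcopy(shared), copy.deepcopy(depth_head))
optimizer = torch.optim.SGD(task_model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(task_model(x), y)
    loss.backward()
    optimizer.step()
```

The trade-off is that you end up with one fine-tuned copy per task again, so this only makes sense if the extra accuracy is worth giving up the single-model deployment.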

With one dataset per head, how am I supposed to train the whole model?

All you need is a PyTorch container such as nn.ModuleList, or nn.ModuleDict when you want to look the heads up by name, as in this sketch:

# Please note this is just a sketch and I'm not well versed in computer vision,
# so check that the resnet50 import is correct and supply your own
# task-specific heads (Depth and Denseflow are placeholders, not defined here).
from torch import nn
from torchvision.models import resnet50

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules

        # shared part
        self.resnet50 = resnet50()

        # task-specific heads, registered by name so their parameters are tracked
        self.tasks = nn.ModuleDict({
            'depth': Depth(),
            'denseflow': Denseflow(),
            # ...
        })

    def forward(self, tasktag, x):
        # shared part
        resnet_output = self.resnet50(x)

        # task-specific part: route to the head selected by tasktag
        return self.tasks[tasktag](resnet_output)
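
As for training with one dataset per head, one common option is to alternate batches between the tasks, so each step updates the shared part plus the head of the task that batch came from. The following is only a sketch under assumptions: `TinyMultiTask`, the task names `task_a` and `task_b`, the layer sizes, and the random data are hypothetical stand-ins for your real backbone, heads, and datasets.

```python
import torch
from torch import nn

torch.manual_seed(0)

class TinyMultiTask(nn.Module):
    """Stand-in model: a shared trunk plus one head per task."""
    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.tasks = nn.ModuleDict({
            'task_a': nn.Linear(16, 1),  # hypothetical regression head
            'task_b': nn.Linear(16, 3),  # hypothetical 3-class head
        })

    def forward(self, tasktag, x):
        return self.tasks[tasktag](self.shared(x))

# One (toy, random) dataset and one loss function per head.
batches = {
    'task_a': [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(5)],
    'task_b': [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(5)],
}
losses = {'task_a': nn.MSELoss(), 'task_b': nn.CrossEntropyLoss()}

model = TinyMultiTask()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

steps_per_task = {tag: 0 for tag in batches}
# Round-robin: take one batch from each task in turn.
for step_batches in zip(*batches.values()):
    for tasktag, (x, y) in zip(batches.keys(), step_batches):
        optimizer.zero_grad()
        loss = losses[tasktag](model(tasktag, x), y)
        loss.backward()  # gradients reach the shared trunk and this head only
        optimizer.step()
        steps_per_task[tasktag] += 1
```

If your datasets differ a lot in size, note that `zip` stops at the shortest one, so you would probably want to sample the task for each step instead (for example, proportionally to dataset size).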
cronoik answered Oct 15 '22 22:10