I have about 5 models that each work pretty well when trained individually, but I want to fuse them into one big model. I'm looking into this because one big model is easier to update in production than many small models. This is an image of what I want to achieve:
My questions are: is it OK to do it like this? And with one dataset per head, how am I supposed to train the whole model?
[Figure: In the multi-head model, the output layer allocates a different set of output units (head) for each task. The single-head model uses the same output layer for each task.]
Multi-task learning (MTL) is a machine learning approach in which a single model learns multiple tasks simultaneously by optimizing several loss functions at once, rather than training an independent model for each task. The tasks share a common representation, which lets the model exploit what a group of related tasks have in common.
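Written out, one common way to optimize several losses at once (a formulation assumed here for illustration, not the only one) is to minimize a weighted sum of the per-task losses. Here θ_sh are the shared parameters, θ_t the parameters of task t's head, T the number of tasks, and w_t a per-task weight you choose; setting every w_t = 1 is a reasonable starting point:

```latex
\mathcal{L}(\theta_{\text{sh}}, \theta_1, \dots, \theta_T)
  = \sum_{t=1}^{T} w_t \, \mathcal{L}_t(\theta_{\text{sh}}, \theta_t)
```

Because θ_sh appears in every term, the shared part receives gradients from all tasks, while each head θ_t only receives gradients from its own loss.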
My question is: is it OK to do it like this?
Sure, you can do that. This approach is called multi-task learning. Depending on your datasets and what you are trying to do, it may even improve performance. Microsoft used a multi-task model to achieve strong results on the GLUE NLP benchmark, but they also noted that performance can be improved further by fine-tuning the joint model on each individual task.
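Fine-tuning per task could look roughly like the sketch below: copy the jointly trained model and keep training the copy on one task's data only, so the shared model stays intact. `finetune_for_task` and the toy `TwoHeadModel` are hypothetical names for illustration, the model is assumed to take `(x, tasktag)`, and the random tensors stand in for a real task dataset.

```python
import copy

import torch
from torch import nn


def finetune_for_task(joint_model, tasktag, batches, loss_fn, lr=1e-4):
    """Return a copy of `joint_model` fine-tuned on a single task.

    The joint model itself is left untouched, so you keep the shared model
    and add per-task specialists only where they pay off.
    """
    model = copy.deepcopy(joint_model)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for x, y in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(x, tasktag), y)
        loss.backward()  # only the shared layer and this task's head get gradients
        optimizer.step()
    return model


class TwoHeadModel(nn.Module):
    """Hypothetical toy joint model: one shared layer plus two named heads."""

    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 8)
        self.tasks = nn.ModuleDict({'a': nn.Linear(8, 1), 'b': nn.Linear(8, 1)})

    def forward(self, x, tasktag):
        return self.tasks[tasktag](torch.relu(self.shared(x)))


# Usage with random stand-in data for task 'a'.
torch.manual_seed(0)
joint = TwoHeadModel()
batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(5)]
specialist_a = finetune_for_task(joint, 'a', batches, nn.MSELoss())
```

Since head `b` never receives a gradient during this loop, its weights in the copy stay identical to the joint model's, while the shared layer drifts toward task `a`.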
With one dataset per head, how am I supposed to train the whole model?
All you need is PyTorch's `nn.ModuleList` (or `nn.ModuleDict`, which lets you look up the heads by name):
```python
# Note: I'm not well versed in computer vision, so treat this as a sketch.
# The two linear heads are placeholders: replace them with your real
# Depth / Denseflow modules (and add their imports).
from torch import nn
from torchvision.models import resnet50


class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared part: ResNet-50 with its classifier removed, so it returns
        # a 2048-dim feature vector per image.
        self.resnet50 = resnet50()
        self.resnet50.fc = nn.Identity()
        # Task-specific heads, registered by name so they can be looked up.
        self.tasks = nn.ModuleDict({
            'depth': nn.Linear(2048, 1),      # placeholder for Depth()
            'denseflow': nn.Linear(2048, 2),  # placeholder for Denseflow()
            # ...
        })

    def forward(self, x, tasktag):
        # Shared part
        resnet_output = self.resnet50(x)
        # Task-specific part
        return self.tasks[tasktag](resnet_output)
```
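As for training with one dataset per head, one option (a sketch, not the only way) is to keep one `DataLoader` per task and, at each step, draw one batch from every task's loader, run each batch through its own head, sum the losses with equal weights and do a single backward pass, so the shared trunk receives gradients from all tasks. To keep the example self-contained and fast, `ToyMultiTaskModel` is a hypothetical stand-in with a tiny linear trunk instead of ResNet-50, the heads are named after the question's tasks, the data is random, and `train_multitask` is an illustrative helper, not a PyTorch API.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyMultiTaskModel(nn.Module):
    """Hypothetical stand-in: a small shared trunk plus one head per task."""

    def __init__(self, in_dim=8, hidden=16):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.tasks = nn.ModuleDict({
            'depth': nn.Linear(hidden, 1),
            'denseflow': nn.Linear(hidden, 2),
        })

    def forward(self, x, tasktag):
        return self.tasks[tasktag](self.shared(x))


def train_multitask(model, loaders, loss_fns, epochs=1, lr=1e-3):
    """Train on one dataset per head; returns the summed loss of every step."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    model.train()
    for _ in range(epochs):
        # One batch from each task's loader per step; zip stops when the
        # shortest loader is exhausted.
        for batches in zip(*loaders.values()):
            optimizer.zero_grad()
            total = 0.0
            for task, (x, y) in zip(loaders.keys(), batches):
                total = total + loss_fns[task](model(x, task), y)
            total.backward()  # shared trunk accumulates gradients from every task
            optimizer.step()
            history.append(total.item())
    return history


# Usage with random stand-in data: each task has its own dataset.
torch.manual_seed(0)
loaders = {
    'depth': DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
                        batch_size=16, shuffle=True),
    'denseflow': DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 2)),
                            batch_size=16, shuffle=True),
}
model = ToyMultiTaskModel()
losses = train_multitask(model, loaders,
                         {'depth': nn.MSELoss(), 'denseflow': nn.MSELoss()},
                         epochs=3)
```

Alternating instead (a full optimizer step on one task's batch, then on the next task's) is also common. Summing with equal weights is simply the easiest place to start; if one task dominates the shared features, you can scale its loss down.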