TensorFlow's TimeDistributed equivalent in PyTorch

Is there an equivalent of tensorflow.keras.layers.TimeDistributed in PyTorch?

I am trying to build something like TimeDistributed(ResNet50()).

asked Mar 03 '23 by Uiyun Kim

1 Answer

Credit to miguelvr for this solution.

You can use the following PyTorch module, which mimics Keras's TimeDistributed wrapper.

import torch.nn as nn

class TimeDistributed(nn.Module):
    """Applies `module` to every timestep of a 3D input.

    Expects input of shape (samples, timesteps, input_size) when
    batch_first=True, or (timesteps, samples, input_size) otherwise.
    """
    def __init__(self, module, batch_first=False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        # No time axis to distribute over: apply the module directly
        if x.dim() <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)

        y = self.module(x_reshape)

        # Restore the time axis
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)

        return y
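
Note that this module flattens only the last dimension, so it assumes a 3D input where each timestep is a vector. For something like TimeDistributed(ResNet50()), where each timestep is an image of shape (C, H, W), you have to merge the batch and time axes across all trailing dimensions instead. Below is a minimal sketch of that variant; the class name TimeDistributedND is my own, and the torchvision import is an assumption based on the ResNet50 mentioned in the question.

import torch
import torch.nn as nn
from torchvision import models

class TimeDistributedND(nn.Module):
    """Applies `module` to every timestep of a (batch, time, *dims) tensor."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # Merge batch and time: (b, t, *dims) -> (b * t, *dims)
        b, t = x.size(0), x.size(1)
        x_reshape = x.contiguous().view(b * t, *x.shape[2:])
        y = self.module(x_reshape)
        # Restore the (batch, time, ...) layout
        return y.view(b, t, *y.shape[1:])

# Usage sketch: wrap a ResNet50 backbone (hypothetical example)
resnet = models.resnet50(weights=None)
td = TimeDistributedND(resnet)
frames = torch.randn(2, 5, 3, 224, 224)  # (batch=2, timesteps=5, C, H, W)
out = td(frames)                          # (2, 5, 1000)

Because the batch and time axes are contiguous after the merge, a plain view is enough to restore them afterwards, so no data is copied beyond the initial contiguous() call.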
answered Apr 28 '23 by Bando