
Moving member tensors with module.to() in PyTorch

I am building a Variational Autoencoder (VAE) in PyTorch and have a problem writing device-agnostic code. The autoencoder is a subclass of nn.Module with an encoder and a decoder network, which are nn.Modules too. All weights of the network can be moved from one device to another by calling net.to(device).

The problem I have is with the reparametrization trick:

encoding = mu + noise * sigma
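For context, here is a minimal sketch of the reparametrization trick in isolation. The shapes and values are arbitrary and chosen only for illustration, and sigma is computed as exp(0.5 * log_sigma) to match the forward pass in the code below. Because all the randomness lives in `noise`, which is sampled independently of the network outputs, gradients flow back into `mu` and `log_sigma`.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 4 samples, 8-dimensional embedding
mu = torch.zeros(4, 8, requires_grad=True)
log_sigma = torch.zeros(4, 8, requires_grad=True)

# Noise drawn from N(0, 1); it does not require gradients
noise = torch.randn(4, 8)

# encoding = mu + noise * sigma, with sigma = exp(0.5 * log_sigma)
encoding = mu + noise * torch.exp(0.5 * log_sigma)

# Any scalar loss can now be backpropagated to mu and log_sigma
encoding.sum().backward()
print(mu.grad.shape, log_sigma.grad.shape)
```

With log_sigma at zero, the gradient with respect to log_sigma is simply 0.5 * noise, while the gradient with respect to mu is all ones.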

The noise is a tensor of the same size as mu and sigma and is stored as a member variable of the autoencoder module. It is initialized in the constructor and resampled in place at each training step. I do it this way to avoid constructing a new noise tensor at every step and pushing it to the desired device. Additionally, I want to keep the noise fixed during evaluation. Here is the code:
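The in-place resampling this relies on can be checked on its own. `Tensor.normal_()` overwrites the existing storage with fresh N(0, 1) samples, so no new tensor is allocated and the tensor stays on whatever device it already lives on. A minimal sketch with arbitrary sizes:

```python
import torch

noise = torch.randn(16, 128)
ptr_before = noise.data_ptr()
old_values = noise.clone()

# Resample in place: same storage, new values
noise.normal_()

print(noise.data_ptr() == ptr_before)  # True: no new allocation
# The values themselves have been redrawn
print(torch.equal(noise, old_values))
```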

import torch
import torch.nn as nn


class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, batch_size):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        # Plain tensor attributes: these are NOT moved by net.to()
        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder (Encoder and Decoder are defined elsewhere)
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        # nn.Module.eval() calls self.train(False), so pick the noise from mode
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise if mode else self._eval_noise
        return self

    def eval(self):
        return self.train(False)

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder(inputs)
        # Resample noise in place if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

When I now move the autoencoder to the GPU with net.to('cuda:0'), I get an error in the forward pass because the noise tensor is not moved.
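The same failure can be reproduced without a GPU by converting the dtype instead of the device: `Module.to()` converts parameters, buffers and submodules, but a tensor stored as a plain attribute is left untouched. A minimal sketch, using a hypothetical `Toy` module as a stand-in for the generator:

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Hypothetical stand-in: one layer plus a plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.noise = torch.randn(2, 4)  # plain attribute, invisible to .to()


net = Toy().to(torch.float64)
print(net.linear.weight.dtype)  # torch.float64
print(net.noise.dtype)          # torch.float32: not converted

# Mixing the two fails, much like mixing cuda and cpu tensors in forward()
try:
    net.linear(net.noise)
except RuntimeError as err:
    print("RuntimeError:", err)
```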

I don't want to add a device parameter to the constructor, because then it is still not possible to move the module to another device later. I also tried wrapping the noise in an nn.Parameter so that net.to() picks it up, but then the optimizer raises an error because the noise is flagged with requires_grad=False.

Does anyone have a solution that moves all of the module's tensors with net.to()?

tilman151 asked Dec 06 '22 10:12


1 Answer

A better version of tilman151's second approach is probably to override _apply rather than to. That way net.cuda(), net.float(), etc. will all work as well, since those methods, like to() itself, are implemented via _apply (as can be seen in the source, which is simpler than you might think):

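def _apply(self, fn, *args, **kwargs):
    # Convert parameters, buffers and submodules as usual
    super(VariationalGenerator, self)._apply(fn, *args, **kwargs)
    # Apply the same conversion to the plain noise tensors
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    # Re-point the alias; otherwise self.noise keeps the old tensor
    self.noise = self._train_noise if self.training else self._eval_noise
    return self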
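A self-contained sketch of the override in action, under these assumptions: the encoder and decoder are replaced by hypothetical `nn.Linear` layers, the constructor takes `batch_size` and `embedding_size`, and `_apply` forwards any extra arguments because newer PyTorch versions give `_apply` an additional `recurse` parameter. It also re-points `self.noise` after conversion, since rebinding `_train_noise` alone would leave `noise` referring to the old tensor. `net.double()` stands in for `net.to('cuda')` so the sketch runs on CPU; both go through `_apply`.

```python
import torch
import torch.nn as nn


class VariationalGenerator(nn.Module):
    # Hypothetical minimal version: Linear layers stand in for Encoder/Decoder
    def __init__(self, batch_size=4, embedding_size=8):
        super().__init__()
        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise
        self.encoder = nn.Linear(embedding_size, 2 * embedding_size)
        self.decoder = nn.Linear(embedding_size, embedding_size)

    def train(self, mode=True):
        super().train(mode)
        self.noise = self._train_noise if mode else self._eval_noise
        return self

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert parameters, buffers and submodules first
        super()._apply(fn, *args, **kwargs)
        # Apply the same conversion (device move, dtype cast, ...) to plain tensors
        self._train_noise = fn(self._train_noise)
        self._eval_noise = fn(self._eval_noise)
        # Re-point the alias so it does not keep the pre-conversion tensor
        self.noise = self._train_noise if self.training else self._eval_noise
        return self

    def forward(self, inputs):
        # Encoder output is split into mu and log_sigma halves
        mu, log_sigma = self.encoder(inputs).chunk(2, dim=-1)
        if self.training:
            self.noise.normal_()
        return self.decoder(mu + self.noise * torch.exp(0.5 * log_sigma))


net = VariationalGenerator().double()  # .double() also goes through _apply
out = net(torch.randn(4, 8, dtype=torch.float64))
print(net.noise.dtype, out.dtype)  # torch.float64 torch.float64
```

Calling `net.eval()` afterwards switches `net.noise` to the converted evaluation noise, because `eval()` routes through the overridden `train()`.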
Danica answered Jan 13 '23 21:01
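For comparison, here is a different technique that the answer does not use: registering the noise tensors as buffers with `register_buffer(..., persistent=False)` (the `persistent` flag is available in recent PyTorch versions). Buffers are moved by `net.to()`, `net.cuda()` and friends automatically, they are not parameters, so the optimizer never sees them, and no private method has to be overridden. Because `_apply` replaces each buffer with a converted tensor, the sketch exposes `noise` as a read-only property that looks up the current buffer on each access instead of caching an alias. The `NoiseHolder` module and its sizes are hypothetical.

```python
import torch
import torch.nn as nn


class NoiseHolder(nn.Module):
    """Hypothetical module holding only the noise tensors, as buffers."""

    def __init__(self, batch_size=4, embedding_size=8):
        super().__init__()
        # Buffers follow .to()/.cuda()/.double() but are not parameters;
        # persistent=False keeps them out of state_dict
        self.register_buffer("_train_noise", torch.randn(batch_size, embedding_size), persistent=False)
        self.register_buffer("_eval_noise", torch.randn(1, embedding_size), persistent=False)

    @property
    def noise(self):
        # Looked up on every access, so it always reflects the latest .to() call
        return self._train_noise if self.training else self._eval_noise


net = NoiseHolder().to(torch.float64)
print(net.noise.dtype)         # torch.float64
print(list(net.parameters()))  # []
print(list(net.state_dict()))  # []
```

In-place resampling with `net.noise.normal_()` still works, because the property returns the buffer itself. Whether the fixed evaluation noise should be saved with checkpoints is a design choice; dropping `persistent=False` would include both buffers in `state_dict`.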