I am building a Variational Autoencoder (VAE) in PyTorch and have a problem writing device-agnostic code. The autoencoder is a child of nn.Module, as are its encoder and decoder networks. All weights of the network can be moved from one device to another by calling net.to(device).
The problem I have is with the reparametrization trick:

encoding = mu + noise * sigma

The noise is a tensor of the same size as mu and sigma and is saved as a member variable of the autoencoder module. It is initialized in the constructor and resampled in-place at each training step. I do it that way to avoid constructing a new noise tensor at each step and pushing it to the desired device. Additionally, I want the noise to be fixed during evaluation. Here is the code:
import torch
import torch.nn as nn


class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128
        # batch_size, Encoder and Decoder are defined elsewhere
        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise
        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)
        return inputs, mu, log_sigma
When I now move the autoencoder to the GPU with net.to('cuda:0'), I get an error in forward because the noise tensor is not moved.

I don't want to add a device parameter to the constructor, because then it would still not be possible to move the module to another device later. I also tried to wrap the noise in nn.Parameter so that it is affected by net.to(), but that raises an error from the optimizer, because the noise is flagged with requires_grad=False.

Does anyone have a solution to move the whole module, noise tensors included, with net.to()?
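For comparison with the answer below: the standard nn.Module mechanism for this situation is register_buffer. Buffers are moved by net.to(), net.cuda(), net.float(), etc. together with the parameters, but are not returned by net.parameters(), so the optimizer never touches them. A minimal sketch under the question's assumptions (batch_size, Encoder and Decoder defined elsewhere); note that buffers are also saved in state_dict by default:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()
        embedding_size = 128
        # Registered buffers move with the module on every .to()/.cuda()/.float()
        self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size))
        self.register_buffer('_eval_noise', torch.randn(1, embedding_size))
        self.encoder = Encoder(input_nc, embedding_size)
        self.decoder = Decoder(output_nc, embedding_size)

    def forward(self, inputs):
        mu, log_sigma = self.encoder(inputs)
        # Look the buffer up on each call instead of caching an alias in
        # self.noise, so the current (possibly moved) tensor is always used
        noise = self._train_noise if self.training else self._eval_noise
        if self.training:
            noise.normal_()
        return self.decoder(mu + noise * torch.exp(0.5 * log_sigma)), mu, log_sigma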
A better version of tilman151's second approach is probably to override _apply, rather than to. That way net.cuda(), net.float(), etc. will all work as well, since those all call _apply rather than to (as can be seen in the source, which is simpler than you might think):
def _apply(self, fn):
    # fn is the tensor transform (move/cast) that nn.Module applies to all
    # parameters and buffers; apply it to the plain-tensor members as well
    super(VariationalGenerator, self)._apply(fn)
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    return self