I'm implementing a U-Net based architecture in PyTorch. At train time I have patches of size 256x256, which don't cause any problems. At test time, however, I have full-HD images (1920x1080), and this causes a problem with the skip connections.
Downsampling 1920x1080 three times gives 240x135. If I downsample one more time, the resolution becomes 120x68, which when upsampled gives 240x136. Now I cannot concatenate these two feature maps. How can I solve this?
PS: I thought this was a fairly common problem, but I couldn't find any solution, or even a mention of this problem, anywhere on the web. Am I missing something?
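The mismatch is easy to reproduce in isolation. A minimal sketch (whether the pooled height is 67 or 68 depends on the pooling layer's ceil_mode; either way, the round trip does not return 135):

```python
import torch
import torch.nn.functional as F

# Feature map after three 2x downsamplings of a 1080x1920 input: 135x240.
x = torch.zeros(1, 1, 135, 240)

# A fourth downsampling floors the odd height: 135 -> 67.
down = F.max_pool2d(x, kernel_size=2)
print(down.shape)  # torch.Size([1, 1, 67, 120])

# Upsampling by 2 gives 134, not 135, so the skip connection cannot concatenate.
up = F.interpolate(down, scale_factor=2)
print(up.shape)    # torch.Size([1, 1, 134, 240])
```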
This is rather a theoretical question, but I'd like to know whether one can create U-Nets for different input image sizes. I already know that the U-Net's input size may not match its output size; that should be possible.
If you stick to certain spatial sizes, e.g. powers of two, I would assume that most U-Net implementations can handle the input. If I recall correctly, the input and output shapes in the original U-Net paper do not match, but you could implement it in a matching way.
It is a very common problem in segmentation networks, where skip connections are often involved in the decoding process. Networks usually (depending on the actual architecture) require an input size whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.).
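For a U-Net with four 2x downsamplings, the largest stride is 2**4 = 16, which is exactly why 1920x1080 fails. A quick check (the depth of 4 is an assumption taken from the question):

```python
depth = 4            # number of 2x downsamplings, as in the question
stride = 2 ** depth  # largest stride: 16

h, w = 1080, 1920
# Height leaves a remainder of 8, so it is not a feasible input size;
# width is already an exact multiple of 16.
print(h % stride, w % stride)  # 8 0
```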
There are two main ways:

1. Resize the input to the nearest feasible size.
2. Pad the input to the next larger feasible size.

I prefer (2) because (1) causes small changes at the pixel level for all the pixels, leading to unnecessary blurriness. Note that in both methods we usually need to recover the original shape afterward.
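For completeness, method (1) can be done with F.interpolate. This is only a sketch (resize_to_multiple is a name I made up), and the interpolation is exactly what introduces the slight blur mentioned above:

```python
import torch
import torch.nn.functional as F

def resize_to_multiple(x, stride):
    # Round each spatial side to the nearest multiple of stride (at least one stride).
    h, w = x.shape[-2:]
    new_h = max(stride, round(h / stride) * stride)
    new_w = max(stride, round(w / stride) * stride)
    return F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)

x = torch.zeros(4, 3, 1080, 1920)
y = resize_to_multiple(x, 16)
print(y.shape)  # torch.Size([4, 3, 1088, 1920])
```

Remember that the network output then has to be resized back to the original 1080x1920 before comparing against full-resolution labels.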
My favorite code snippet for this task (symmetric padding for height/width):
import torch
import torch.nn.functional as F

def pad_to(x, stride):
    """Pad the last two dims of x up to the next multiple of stride.

    Returns the padded tensor and the pad amounts needed to undo it.
    """
    h, w = x.shape[-2:]

    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w

    # Split the extra rows/columns as evenly as possible between the
    # two sides (symmetric padding).
    lh, uh = (new_h - h) // 2, (new_h - h) - (new_h - h) // 2
    lw, uw = (new_w - w) // 2, (new_w - w) - (new_w - w) // 2
    pads = (lw, uw, lh, uh)  # F.pad order for the last two dims: (left, right, top, bottom)

    # Zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

def unpad(x, pad):
    if pad[2] + pad[3] > 0:
        x = x[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        x = x[:, :, :, pad[0]:-pad[1]]
    return x
A test snippet:
x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)
Output:
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33
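One way to bake this into inference is a thin wrapper module that pads before the network and un-pads its output. This is only a sketch: SegWrapper and the dummy inner network are hypothetical stand-ins, and pad_to/unpad are the functions above, repeated in compact form so the snippet is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pad_to(x, stride):
    # Symmetric zero-padding up to the next multiple of stride (compact form).
    h, w = x.shape[-2:]
    new_h = h + (stride - h % stride) % stride
    new_w = w + (stride - w % stride) % stride
    lh, uh = (new_h - h) // 2, (new_h - h) - (new_h - h) // 2
    lw, uw = (new_w - w) // 2, (new_w - w) - (new_w - w) // 2
    pads = (lw, uw, lh, uh)
    return F.pad(x, pads), pads

def unpad(x, pad):
    if pad[2] + pad[3] > 0:
        x = x[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        x = x[:, :, :, pad[0]:-pad[1]]
    return x

class SegWrapper(nn.Module):
    """Hypothetical wrapper: pad the input, run the network, un-pad the output."""
    def __init__(self, net, stride):
        super().__init__()
        self.net = net
        self.stride = stride

    def forward(self, x):
        x, pads = pad_to(x, self.stride)
        out = self.net(x)
        return unpad(out, pads)

# Stand-in for a real U-Net: downsample 16x, then upsample back.
dummy_net = nn.Sequential(nn.MaxPool2d(16), nn.Upsample(scale_factor=16))
model = SegWrapper(dummy_net, stride=16)

x = torch.zeros(4, 3, 1080, 1920)
out = model(x)
print(out.shape)  # torch.Size([4, 3, 1080, 1920])
```

Because the wrapper hides the padding, the rest of the pipeline can keep working with the original 1920x1080 resolution.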