
How to handle odd resolutions in Unet architecture PyTorch

I'm implementing a U-Net based architecture in PyTorch. At train time, I have patches of size 256x256, which don't cause any problems. However, at test time I have full-HD images (1920x1080). This causes a problem with the skip connections.

Downsampling 1920x1080 three times gives 240x135. If I downsample one more time, the resolution becomes 120x68, which when upsampled gives 240x136. Now I cannot concatenate these two feature maps. How can I solve this?
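
To make the mismatch concrete, here is a minimal repro (ceil-mode pooling is just an assumption to reproduce the 120x68 above):

import torch
import torch.nn.functional as F

feat = torch.zeros(1, 64, 135, 240)                       # encoder features after 3 downsamplings of 1080x1920
down = F.max_pool2d(feat, kernel_size=2, ceil_mode=True)  # 4th downsampling -> (1, 64, 68, 120)
up = F.interpolate(down, scale_factor=2)                   # decoder upsampling -> (1, 64, 136, 240)
print(feat.shape, up.shape)                                # heights 135 vs 136: cannot concatenate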

PS: I thought this was a fairly common problem, but I couldn't find any solution, or even a mention of this problem, anywhere on the web. Am I missing something?

Asked by Nagabhushan S N on Feb 03 '21

1 Answer

This is a very common problem in segmentation networks, where skip connections are involved in the decoding process. Such networks usually (depending on the actual architecture) require input side lengths that are integer multiples of the largest stride (8, 16, 32, etc.). For example, with a largest stride of 16, 1920 is divisible by 16 but 1080 is not (1080 = 67 * 16 + 8), which is exactly what causes the mismatch here.

There are two main ways:

  1. Resize input to the nearest feasible size.
  2. Pad the input to the next larger feasible size.

I prefer (2) because (1) slightly changes the value of every pixel, which introduces unnecessary blurriness. Note that in both methods we usually need to recover the original shape afterwards.
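
For reference, option (1) could look something like this (a minimal sketch with bilinear interpolation; the helper name, the stride of 16, and the interpolation mode are assumptions, not fixed by the architecture):

import torch
import torch.nn.functional as F

def resize_to_multiple(x, stride, mode="bilinear"):
    # Resize both spatial sides to the nearest multiple of `stride`.
    h, w = x.shape[-2:]
    new_h = max(stride, int(round(h / stride)) * stride)
    new_w = max(stride, int(round(w / stride)) * stride)
    x = F.interpolate(x, size=(new_h, new_w), mode=mode, align_corners=False)
    return x, (h, w)

x = torch.zeros(4, 3, 1080, 1920)
x_resized, orig_size = resize_to_multiple(x, 16)  # 1080x1920 -> 1088x1920
# ... run the network on x_resized ...
out = F.interpolate(x_resized, size=orig_size, mode="bilinear", align_corners=False)
print(x_resized.shape, out.shape)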

My favorite code snippet for this task (symmetric padding for height/width):

import torch
import torch.nn.functional as F

def pad_to(x, stride):
    h, w = x.shape[-2:]

    # Round the height and width up to the next multiple of `stride`.
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w
    # Split the padding (roughly) evenly between the two sides of each dimension.
    lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
    lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
    pads = (lw, uw, lh, uh)

    # Zero-padding by default.
    # See other modes at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

def unpad(x, pad):
    # Crop away the padding added by pad_to; pad = (left, right, top, bottom).
    if pad[2]+pad[3] > 0:
        x = x[:,:,pad[2]:-pad[3],:]
    if pad[0]+pad[1] > 0:
        x = x[:,:,:,pad[0]:-pad[1]]
    return x

A test snippet:

x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network 
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape

print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)

Output:

Original:  torch.Size([4, 3, 1080, 1920])
Padded:  torch.Size([4, 3, 1088, 1920])
Recovered:  torch.Size([4, 3, 1080, 1920])

Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33
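
At test time the typical flow is then: pad, run the network, and un-pad the prediction (a sketch; model stands for your trained U-Net and 16 for its largest stride):

model.eval()
with torch.no_grad():
    x_pad, pads = pad_to(x, 16)  # pad to a multiple of the largest stride
    y_pad = model(x_pad)         # prediction at the padded resolution
    y = unpad(y_pad, pads)       # crop back to the original 1080x1920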

Answered by hkchengrex on Oct 21 '22