PyTorch offers the torch.Tensor.unfold operation, which can be chained across arbitrarily many dimensions to extract overlapping patches. How can we reverse the patch extraction so that the patches are combined back into the input shape?
The focus is on 3D volumetric images with 1 channel (biomedical). Extracting is possible with unfold; how can we combine the patches back if they overlap?
The above solution makes copies in memory because it keeps the patches contiguous. This leads to memory issues for large volumes with many overlapping voxels. To extract patches without making a copy in memory, we can do the following in PyTorch:
import torch

def get_dim_blocks(dim_in, kernel_size, padding=0, stride=1, dilation=1):
    # Number of sliding windows along one dimension (same formula as the
    # output size of a convolution).
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def extract_patches_3d(x, kernel_size, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    x = x.contiguous()

    channels, depth, height, width = x.shape[-4:]
    d_blocks = get_dim_blocks(depth, kernel_size=kernel_size[0], stride=stride[0], dilation=dilation[0])
    h_blocks = get_dim_blocks(height, kernel_size=kernel_size[1], stride=stride[1], dilation=dilation[1])
    w_blocks = get_dim_blocks(width, kernel_size=kernel_size[2], stride=stride[2], dilation=dilation[2])
    # One axis per channel, per block position and per kernel offset.
    shape = (channels, d_blocks, h_blocks, w_blocks, kernel_size[0], kernel_size[1], kernel_size[2])
    # Strides in elements: stepping a block axis moves by `stride` voxels,
    # stepping a kernel axis moves by `dilation` voxels.
    strides = (width*height*depth,
               stride[0]*width*height,
               stride[1]*width,
               stride[2],
               dilation[0]*width*height,
               dilation[1]*width,
               dilation[2])
    # View into the same storage, no copy. The view starts at the first
    # element, so it covers a single volume (batch size 1).
    x = x.as_strided(shape, strides)
    x = x.permute(1, 2, 3, 0, 4, 5, 6)
    return x
The method expects a tensor of shape `(B,C,D,H,W)`. It is based on this and this answer (in NumPy), which explain in more detail how memory strides work. The output will be non-contiguous, and the first 3 dimensions will be the number of blocks (sliding windows) in the D, H and W dimensions. Combining them into 1 dimension is not possible, as this would require a copy into contiguous memory.
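For illustration, a minimal sketch of that limitation; the volume `v` and the kernel/stride values below are just example assumptions:

v = torch.randn(1, 1, 8, 8, 8)                      # (B, C, D, H, W), single-channel volume
p = extract_patches_3d(v, kernel_size=4, stride=2)  # overlapping windows, still a view
q = p.reshape(-1, *p.shape[3:])                     # flatten the block dims -> (num_blocks, C, 4, 4, 4)
print(p.data_ptr() == v.data_ptr())                 # True:  p shares storage with v
print(q.data_ptr() == v.data_ptr())                 # False: reshape had to materialize a copy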
Test with stride 2 (kernel size 2, so the windows do not overlap):
a = torch.arange(81, dtype=torch.float32).view(1,3,3,3,3)
print(a)
b = extract_patches_3d(a, kernel_size=2, stride=2)
print(b.shape)
print(b.storage())
print(a.data_ptr() == b.data_ptr())
print(b)
Output
tensor([[[[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.],
           [ 6.,  7.,  8.]],
          [[ 9., 10., 11.],
           [12., 13., 14.],
           [15., 16., 17.]],
          [[18., 19., 20.],
           [21., 22., 23.],
           [24., 25., 26.]]],
         [[[27., 28., 29.],
           [30., 31., 32.],
           [33., 34., 35.]],
          [[36., 37., 38.],
           [39., 40., 41.],
           [42., 43., 44.]],
          [[45., 46., 47.],
           [48., 49., 50.],
           [51., 52., 53.]]],
         [[[54., 55., 56.],
           [57., 58., 59.],
           [60., 61., 62.]],
          [[63., 64., 65.],
           [66., 67., 68.],
           [69., 70., 71.]],
          [[72., 73., 74.],
           [75., 76., 77.],
           [78., 79., 80.]]]]])
torch.Size([1, 1, 1, 3, 2, 2, 2])
 0.0
 1.0
 2.0
 3.0
 4.0
 5.0
 6.0
 7.0
 8.0
 9.0
 10.0
 11.0
 12.0
 13.0
 14.0
 15.0
 16.0
 17.0
 18.0
 19.0
 20.0
 21.0
 22.0
 23.0
 24.0
 25.0
 26.0
 27.0
 28.0
 29.0
 30.0
 31.0
 32.0
 33.0
 34.0
 35.0
 36.0
 37.0
 38.0
 39.0
 40.0
 41.0
 42.0
 43.0
 44.0
 45.0
 46.0
 47.0
 48.0
 49.0
 50.0
 51.0
 52.0
 53.0
 54.0
 55.0
 56.0
 57.0
 58.0
 59.0
 60.0
 61.0
 62.0
 63.0
 64.0
 65.0
 66.0
 67.0
 68.0
 69.0
 70.0
 71.0
 72.0
 73.0
 74.0
 75.0
 76.0
 77.0
 78.0
 79.0
 80.0
[torch.FloatStorage of size 81]
True
tensor([[[[[[[ 0.,  1.],
             [ 3.,  4.]],
            [[ 9., 10.],
             [12., 13.]]],
           [[[27., 28.],
             [30., 31.]],
            [[36., 37.],
             [39., 40.]]],
           [[[54., 55.],
             [57., 58.]],
            [[63., 64.],
             [66., 67.]]]]]]])
Reversing with summation of overlapping voxels using memory strides is not possible, assuming that the tensor is contiguous (as it would be after processing in a NN). However, you can manually sum the overlapping patches as explained above, or with slicing as explained here.
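A minimal sketch of that summation, assuming the patch layout returned by extract_patches_3d above (block dimensions first, then channel and kernel dimensions), a dilation of 1 and no padding; combine_patches_3d is a hypothetical helper, not a PyTorch function:

def combine_patches_3d(patches, output_shape, kernel_size, stride=1):
    # patches: (d_blocks, h_blocks, w_blocks, C, kD, kH, kW), as produced above.
    # output_shape: (C, D, H, W) of the original volume.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    out = torch.zeros(output_shape, dtype=patches.dtype, device=patches.device)
    d_blocks, h_blocks, w_blocks = patches.shape[:3]
    for i in range(d_blocks):
        for j in range(h_blocks):
            for k in range(w_blocks):
                d0, h0, w0 = i * stride[0], j * stride[1], k * stride[2]
                # Overlapping voxels are summed; divide by an overlap count
                # afterwards if an average is wanted instead.
                out[:, d0:d0 + kernel_size[0],
                       h0:h0 + kernel_size[1],
                       w0:w0 + kernel_size[2]] += patches[i, j, k]
    return out

a = torch.arange(81, dtype=torch.float32).view(1, 3, 3, 3, 3)
p = extract_patches_3d(a, kernel_size=2, stride=1)               # overlapping patches
r = combine_patches_3d(p, (3, 3, 3, 3), kernel_size=2, stride=1)
# r holds each input voxel multiplied by the number of windows that cover it.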