 

How to extract overlapping patches from a 3D volume and recreate the input shape from the patches?

PyTorch offers the torch.Tensor.unfold operation, which can be chained across arbitrarily many dimensions to extract overlapping patches. How can we reverse the patch extraction so that the patches are combined back into the input shape?

The focus is on 3D volumetric images with one channel (biomedical). Extraction is possible with unfold, but how can we combine the patches back if they overlap?
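For reference, a minimal sketch of the unfold-based extraction (my own illustration, assuming a (B, C, D, H, W) input with the same kernel size k and stride s in every spatial dimension):

import torch

# Chain unfold over the three spatial dimensions.
# Dimensions 2, 3, 4 of a (B, C, D, H, W) tensor are D, H, W.
x = torch.randn(1, 1, 8, 8, 8)   # one-channel 3D volume
k, s = 4, 2                      # patch size and stride (overlapping)
patches = x.unfold(2, k, s).unfold(3, k, s).unfold(4, k, s)
# -> (B, C, n_d, n_h, n_w, k, k, k); each trailing (k, k, k) block is a patch
print(patches.shape)  # torch.Size([1, 1, 3, 3, 3, 4, 4, 4])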

asked Oct 31 '25 by blanNL


1 Answer

The unfold approach makes copies in memory as soon as the patches are made contiguous, which leads to memory issues for large volumes with many overlapping voxels. To extract patches without making a copy, we can do the following in PyTorch:

import torch

def get_dim_blocks(dim_in, kernel_size, padding=0, stride=1, dilation=1):
    # Number of sliding windows along one dimension (the usual conv output size formula).
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def extract_patches_3d(x, kernel_size, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)
    x = x.contiguous()

    channels, depth, height, width = x.shape[-4:]
    d_blocks = get_dim_blocks(depth, kernel_size=kernel_size[0], stride=stride[0], dilation=dilation[0])
    h_blocks = get_dim_blocks(height, kernel_size=kernel_size[1], stride=stride[1], dilation=dilation[1])
    w_blocks = get_dim_blocks(width, kernel_size=kernel_size[2], stride=stride[2], dilation=dilation[2])
    # View axes: one per channel, per block position, and per position inside a patch.
    shape = (channels, d_blocks, h_blocks, w_blocks, kernel_size[0], kernel_size[1], kernel_size[2])
    # Strides in elements: jump a whole D*H*W volume per channel, `stride` voxels
    # per block, and `dilation` voxels per step inside a patch.
    strides = (width*height*depth,
               stride[0]*width*height,
               stride[1]*width,
               stride[2],
               dilation[0]*width*height,
               dilation[1]*width,
               dilation[2])

    x = x.as_strided(shape, strides)
    x = x.permute(1, 2, 3, 0, 4, 5, 6)  # -> (d_blocks, h_blocks, w_blocks, C, kD, kH, kW)
    return x

The method expects a tensor of shape `(B,C,D,H,W)`. It is based on this and this answer (in NumPy), which explain in more detail what memory strides do. The output is non-contiguous, and its first three dimensions are the number of blocks (sliding windows) along the D, H and W dimensions. Combining them into one dimension is not possible, as this would require a copy to contiguous memory.
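As a quick illustration of how memory strides produce overlapping windows without copying (a toy 1D sketch of my own, not part of the method itself):

import torch

# Overlapping windows over a 1D tensor via as_strided.
# The outer stride of 2 starts each window 2 elements later; the inner
# stride of 1 walks through a window element by element.
t = torch.arange(8.)                    # [0, 1, 2, 3, 4, 5, 6, 7]
windows = t.as_strided((3, 4), (2, 1))  # 3 windows of length 4, stride 2
print(windows)
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.]])
print(t.data_ptr() == windows.data_ptr())  # True: same storage, no copy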

Test with stride

a = torch.arange(81, dtype=torch.float32).view(1,3,3,3,3)
print(a)
b = extract_patches_3d(a, kernel_size=2, stride=2)
print(b.shape)
print(b.storage())
print(a.data_ptr() == b.data_ptr())
print(b)

Output

tensor([[[[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.],
           [ 6.,  7.,  8.]],

          [[ 9., 10., 11.],
           [12., 13., 14.],
           [15., 16., 17.]],

          [[18., 19., 20.],
           [21., 22., 23.],
           [24., 25., 26.]]],


         [[[27., 28., 29.],
           [30., 31., 32.],
           [33., 34., 35.]],

          [[36., 37., 38.],
           [39., 40., 41.],
           [42., 43., 44.]],

          [[45., 46., 47.],
           [48., 49., 50.],
           [51., 52., 53.]]],


         [[[54., 55., 56.],
           [57., 58., 59.],
           [60., 61., 62.]],

          [[63., 64., 65.],
           [66., 67., 68.],
           [69., 70., 71.]],

          [[72., 73., 74.],
           [75., 76., 77.],
           [78., 79., 80.]]]]])
torch.Size([1, 1, 1, 3, 2, 2, 2])
 0.0
 1.0
 2.0
 ...
 79.0
 80.0
[torch.FloatStorage of size 81]
True
tensor([[[[[[[ 0.,  1.],
             [ 3.,  4.]],

            [[ 9., 10.],
             [12., 13.]]],


           [[[27., 28.],
             [30., 31.]],

            [[36., 37.],
             [39., 40.]]],


           [[[54., 55.],
             [57., 58.]],

            [[63., 64.],
             [66., 67.]]]]]]])


Reversing with summation of overlapping voxels purely via memory strides is not possible if the result tensor must be contiguous (as it would be after processing in a neural network). However, you can manually sum the patches as explained above, or with slicing as explained here.
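For completeness, a minimal slicing-based sketch of that reverse step (the function name combine_patches_3d, the count tensor, and the dilation-1 assumption are my own additions, not from the original answer). It allocates an output volume, adds each patch back at its origin, and sums wherever patches overlap:

import torch

def combine_patches_3d(patches, output_shape, stride):
    # patches: (n_d, n_h, n_w, C, kD, kH, kW), as returned by extract_patches_3d
    # output_shape: (C, D, H, W); stride: (sD, sH, sW); assumes dilation 1
    n_d, n_h, n_w, channels, kd, kh, kw = patches.shape
    out = torch.zeros(output_shape, dtype=patches.dtype)
    count = torch.zeros(output_shape, dtype=patches.dtype)
    for i in range(n_d):
        for j in range(n_h):
            for k in range(n_w):
                d0, h0, w0 = i * stride[0], j * stride[1], k * stride[2]
                # Accumulate the patch and remember how often each voxel was hit.
                out[:, d0:d0+kd, h0:h0+kh, w0:w0+kw] += patches[i, j, k]
                count[:, d0:d0+kd, h0:h0+kh, w0:w0+kw] += 1
    return out, count

Where coverage is complete, out / count.clamp(min=1) recovers the original voxel values; voxels never covered by any patch (possible when the stride does not tile the volume exactly) remain zero.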

answered Nov 04 '25 by blanNL


