 

How to extract patches from a 3D image in Python?

I have a 3D image of size Depth x Width x Height (for example 10x20x30, meaning 10 images, each of size 20x30).

Given a patch size pd x pw x ph (with pd < Depth, pw < Width, ph < Height), for example 4x4x4, the center point of a patch is at pd/2 x pw/2 x ph/2. Let's call the distance between the center points of two consecutive patches (at time t and time t+1) the stride, for example stride=2.

I want to split the original 3D image into patches with the size and stride given above. How can I do this in Python? Thank you.
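To make the geometry concrete: if the first patch starts at index 0, a patch of size p starting at index start has its center at start + p//2, and consecutive patches start s apart, so (N - p)//s + 1 complete patches fit along an axis of length N. A small sketch under those assumptions (the helper name `patch_centers` is hypothetical):

```python
def patch_centers(n, p, s):
    """Center indices of the patches along one axis.

    n: axis length, p: patch size, s: stride between consecutive patches.
    Assumes the first patch starts at index 0 and keeps only patches that
    fit completely inside the axis.
    """
    count = (n - p) // s + 1  # number of complete patches along this axis
    return [start + p // 2 for start in range(0, count * s, s)]


# Example from the question: 10x20x30 image, 4x4x4 patches, stride 2.
for name, n in zip(("depth", "width", "height"), (10, 20, 30)):
    centers = patch_centers(n, 4, 2)
    print(name, len(centers), centers)
```

With these numbers this gives 4 x 9 x 14 patches; since the stride (2) is smaller than the patch size (4), neighbouring patches overlap.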


asked Feb 20 '17 at 05:02 by user3051460




1 Answer

Use np.lib.stride_tricks.as_strided. This solution does not require the strides to divide the corresponding dimensions of the input stack, and it even allows overlapping patches (in that case, do not write to the result, or make a copy first). It is therefore more flexible than other approaches:

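import numpy as np
from numpy.lib import stride_tricks

def cutup(data, blck, strd):
    sh = np.array(data.shape)
    blck = np.asanyarray(blck)
    strd = np.asanyarray(strd)
    # number of complete blocks that fit along each axis
    nbl = (sh - blck) // strd + 1
    # outer axes step from block to block, inner axes step within a block
    strides = np.r_[data.strides * strd, data.strides]
    dims = np.r_[nbl, blck]
    # 6d view (3 block indices, 3 within-block indices) sharing data's memory
    data6 = stride_tricks.as_strided(data, strides=strides, shape=dims)
    return data6  #.reshape(-1, *blck)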

#demo
x = np.zeros((5, 6, 12), int)
y = cutup(x, (2, 2, 3), (3, 3, 5))

y[...] = 1
print(x[..., 0], '\n')
print(x[:, 0, :], '\n')
print(x[0, ...], '\n')

Output:

[[1 1 0 1 1 0]
 [1 1 0 1 1 0]
 [0 0 0 0 0 0]
 [1 1 0 1 1 0]
 [1 1 0 1 1 0]] 

[[1 1 1 0 0 1 1 1 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]] 

[[1 1 1 0 0 1 1 1 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]] 

Explanation: NumPy arrays are organised in terms of strides, one per dimension. The data point [x, y, z] is located in memory at address base + stride_x * x + stride_y * y + stride_z * z, where the strides are measured in bytes.
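The address formula can be checked directly. Because NumPy reports strides in bytes, for a C-contiguous array the byte offset divided by the item size indexes into the flattened array. A minimal sketch, assuming an explicit `np.int64` dtype so the stride values are predictable:

```python
import numpy as np

x = np.arange(5 * 6 * 12, dtype=np.int64).reshape(5, 6, 12)
print(x.strides)  # (576, 96, 8): bytes to step along each axis

i, j, k = 3, 4, 7
# byte offset of x[i, j, k] from the start of the buffer
offset = i * x.strides[0] + j * x.strides[1] + k * x.strides[2]
flat = x.reshape(-1)  # a view here, since x is C-contiguous
print(flat[offset // x.itemsize] == x[i, j, k])  # True
```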

The stride_tricks.as_strided factory lets you directly manipulate the strides and shape of a new array that shares its memory with a given array. Use it only if you know what you are doing: no checks are performed, so nothing stops you from shooting yourself in the foot by addressing out-of-bounds memory.
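If NumPy 1.20 or newer is available, `numpy.lib.stride_tricks.sliding_window_view` builds the same kind of view with bounds checks and returns a read-only view by default; stepping through its window-start axes reproduces the block layout of `cutup`. This is a swapped-in alternative, not what the answer uses, and `cutup_windows` is a hypothetical name:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def cutup_windows(data, blck, strd):
    # every window of shape blck, one per starting offset (read-only view)
    win = sliding_window_view(data, window_shape=tuple(blck))
    # keep only windows whose start is a multiple of the stride on each axis
    return win[::strd[0], ::strd[1], ::strd[2]]


x = np.arange(5 * 6 * 12).reshape(5, 6, 12)
y = cutup_windows(x, (2, 2, 3), (3, 3, 5))
print(y.shape)            # (2, 2, 2, 2, 2, 3)
print(y.flags.writeable)  # False
```

Because the view is read-only, an assignment like the demo's `y[...] = 1` would raise an error here instead of silently writing into overlapping patches.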

The code uses this function to split each of the three existing dimensions into two new ones: one for the block index along that axis, whose stride is the original stride times the block stride, and one for the coordinate within the block, which keeps the same stride as the original dimension (adjacent points in a block correspond to adjacent points in the whole stack).
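The same split is easiest to see on a single axis. In this minimal 1D sketch (numbers chosen only for illustration), a length-12 array is cut into blocks of 3 taken every 4 elements: the outer axis steps by 4 elements (block stride times original stride), while the inner axis keeps the original stride:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(12)
blck, strd = 3, 4
nbl = (a.size - blck) // strd + 1  # 3 complete blocks fit
s = a.strides[0]                   # bytes between neighbouring elements
blocks = as_strided(a, shape=(nbl, blck), strides=(s * strd, s))
print(blocks)
# [[ 0  1  2]
#  [ 4  5  6]
#  [ 8  9 10]]
```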

All the code does is compute the correct strides and dimensions (the block counts and block dimensions along the three axes).
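For the demo's parameters the computation can be spelled out and compared against plain slicing. A sketch, assuming `int64` data so the original byte strides are (576, 96, 8); the loop is only a correctness check, not something to use on large arrays:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(5 * 6 * 12, dtype=np.int64).reshape(5, 6, 12)
blck = np.array([2, 2, 3])
strd = np.array([3, 3, 5])

st = np.array(x.strides)
nbl = (np.array(x.shape) - blck) // strd + 1  # block counts per axis
strides = np.r_[st * strd, st]                # block axes, then in-block axes
dims = np.r_[nbl, blck]
print(strides.tolist())  # [1728, 288, 40, 576, 96, 8]
print(dims.tolist())     # [2, 2, 2, 2, 2, 3]

y = as_strided(x, shape=dims, strides=strides)

# Brute-force check: every block must equal the corresponding slice of x.
for i, j, k in np.ndindex(*nbl):
    d, r, c = i * strd[0], j * strd[1], k * strd[2]
    assert np.array_equal(y[i, j, k], x[d:d + 2, r:r + 2, c:c + 3])
```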

Since the data are shared with the original array, setting all points of the 6d array to 1 also sets them in the original array, which exposes the block structure in the demo. Note that the commented-out reshape in the last line of the function breaks this link, because it forces a copy.
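`np.shares_memory` makes the difference visible. A small sketch, using a compact `as_strided` call equivalent to `cutup(x, (2, 2, 3), (3, 3, 5))` on an `int64` array, showing that the 6d view shares memory with `x` while the reshaped result does not, so writing to it leaves `x` unchanged:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.zeros((5, 6, 12), dtype=np.int64)
s0, s1, s2 = x.strides
# the same view that cutup(x, (2, 2, 3), (3, 3, 5)) builds
y = as_strided(x, shape=(2, 2, 2, 2, 2, 3),
               strides=(3 * s0, 3 * s1, 5 * s2, s0, s1, s2))
flat = y.reshape(-1, 2, 2, 3)  # 8 patches; the block axes can't merge, so this copies

print(np.shares_memory(x, y))     # True
print(np.shares_memory(x, flat))  # False
flat[...] = 1
print(x.sum())                    # 0: the original array is untouched
```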

answered Sep 30 '22 at 20:09 by Paul Panzer