I'm reading an implementation of im2col from a deep learning book (chapter 7, on CNNs). Its purpose is to transform a 4-dimensional array into a 2-dimensional one, but I don't understand why a 6-dimensional array shows up in the implementation. I'm very curious about the idea behind the algorithm the author used.
I've searched many papers on im2col implementations, but none of them use a high-dimensional array like this. The only material I've found useful so far for visualizing the im2col process is the figure in this paper: HAL Id: inria-00112631
import numpy as np

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Parameters
    ----------
    input_data : 4-D input array of shape (batch size, channel, height, width),
        referred to as (N, C, H, W) below
    filter_h : kernel height
    filter_w : kernel width
    stride : stride size
    pad : padding size

    Returns
    -------
    col : two-dimensional array
    """
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    # Zero-pad only the two spatial dimensions.
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    # For each offset (y, x) inside the kernel, gather the strided grid of pixels
    # that this offset sees across all output positions.
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    # Reorder to (N, out_h, out_w, C, filter_h, filter_w), then flatten to 2-D.
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col
Let's try to visualize what im2col does. It takes as input a stack of color images; the stack has dimensions (image id, color channel, vertical position, horizontal position). For simplicity, let's assume we have only one image:
The first thing it does is padding:
Next, it cuts the image up into windows. The size of the windows is controlled by filter_h/filter_w, and the overlap by stride.
This is where the six dimensions come from: image id (missing in the example because we have only one image), grid height/width, color channel, window height/width.
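To make those six dimensions concrete, here is a small shape-only sketch (the numbers are example values, not from the book): a single RGB image of size 4x4 with a 3x3 kernel, stride 1 and padding 1.

# One RGB image, 4x4 pixels; 3x3 kernel, stride 1, pad 1 (example values).
N, C, H, W = 1, 3, 4, 4
filter_h = filter_w = 3
stride, pad = 1, 1

out_h = (H + 2*pad - filter_h)//stride + 1   # 4
out_w = (W + 2*pad - filter_w)//stride + 1   # 4

# Intermediate 6-D array inside im2col:
# (image id, channel, window row, window column, grid row, grid column)
print((N, C, filter_h, filter_w, out_h, out_w))   # (1, 3, 3, 3, 4, 4)

# After the transpose and reshape it becomes 2-D:
print((N*out_h*out_w, C*filter_h*filter_w))       # (16, 27)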
The algorithm as it stands is a bit clumsy: it assembles the output with its dimensions in the wrong order and then has to correct that using transpose.
Better to get it right in the first place:
def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1

    # Build the output directly in the final order (N, out_h, out_w, C, filter_h, filter_w).
    col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
    for y in range(out_h):
        for x in range(out_w):
            col[:, y, x] = img[..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]

    # Merge (N, out_h, out_w) into rows and (C, filter_h, filter_w) into columns.
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))
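As a quick sanity check (assuming numpy is imported and both functions above are defined; the test data here is made up), the reordered version should produce exactly the same 2-D array as the original:

x = np.random.randn(2, 3, 7, 7)              # two 3-channel 7x7 images (example data)
a = im2col(x, 3, 3, stride=2, pad=1)
b = im2col_better(x, 3, 3, stride=2, pad=1)
print(a.shape)                               # (2*4*4, 3*3*3) = (32, 27)
print(np.allclose(a, b))                     # True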
Just as a side note: we can do even better using stride_tricks and avoid the nested for loop:
def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    NN, CC, HH, WW = img.strides
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1

    # Create a strided view of shape (N, out_h, out_w, C, filter_h, filter_w) without copying;
    # .astype(float) then materializes it as an ordinary array.
    col = np.lib.stride_tricks.as_strided(
        img,
        (N, out_h, out_w, C, filter_h, filter_w),
        (NN, stride*HH, stride*WW, CC, HH, WW)).astype(float)
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))
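Note that as_strided only creates a view; the .astype(float) call is what actually copies the data into an ordinary array. A quick check on made-up data that this version agrees with the loop-based one:

x = np.random.randn(2, 3, 7, 7)              # example data
print(np.allclose(im2col_best(x, 3, 3, stride=2, pad=1),
                  im2col_better(x, 3, 3, stride=2, pad=1)))   # True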
The last thing the algorithm does is reshaping, merging the first three dimensions (again, only two in our example because there is only one image). The red arrows show how the individual windows are lined up into the first new dimension:
The last three dimensions, color channel, y coordinate in window, and x coordinate in window, are merged into the second output dimension. Individual pixels are lined up as indicated by the yellow arrows:
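In index terms, that merge sends the pixel at (channel c, window row y, window column x) to column (c*filter_h + y)*filter_w + x of the 2-D output. A small check of this, using the im2col from the question on made-up data (the indices are arbitrary example values):

x = np.random.randn(1, 3, 5, 5)              # one 3-channel 5x5 image
filter_h = filter_w = 3                      # stride 1, no padding -> out_h = out_w = 3
col2d = im2col(x, filter_h, filter_w, stride=1, pad=0)

n, i, j = 0, 1, 2                            # sample, grid row, grid column
c, y, x_in = 2, 0, 1                         # channel and position inside the window
row = (n*3 + i)*3 + j                        # (n*out_h + i)*out_w + j
column = (c*filter_h + y)*filter_w + x_in
print(np.isclose(col2d[row, column], x[n, c, i + y, j + x_in]))   # True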
It looks like this function just rearranges each of the C colour channels of each of the N images into an (out_h x out_w) grid of overlapping image patches of size (filter_h x filter_w), and then flattens that into a 2-D array where each row is the vector of pixels in one image patch.
The dimensions of the 6-D col (before it is transposed and reshaped) are:
[sample, channel, y_position_within_patch, x_position_within_patch, y_patch_index, x_patch_index]
So, for example, col[n, c, :, :, i, j] will be a 2-D image patch (the i-th patch from the top and the j-th from the left in the grid of image patches).
After the transpose and reshape, that same patch ends up in row (n*out_h + i)*out_w + j of the 2-D output, with the pixels of all C channels flattened into that row; channel c occupies the columns from c*filter_h*filter_w up to (c+1)*filter_h*filter_w.
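A short check of that row mapping, again with the im2col from the question and made-up sizes and indices:

x = np.random.randn(2, 3, 6, 6)              # example data: two 3-channel 6x6 images
filter_h = filter_w = 3
stride, pad = 1, 0
out_h = out_w = (6 - 3)//stride + 1          # 4

col2d = im2col(x, filter_h, filter_w, stride, pad)

n, c, i, j = 1, 2, 0, 3                      # arbitrary sample, channel and patch indices
patch = x[n, c, i*stride:i*stride+filter_h, j*stride:j*stride+filter_w]

row = (n*out_h + i)*out_w + j
channel_c_part = col2d[row].reshape(3, filter_h, filter_w)[c]
print(np.allclose(channel_c_part, patch))    # True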