
Python: Why does this implementation of im2col use a 6-dimensional array?

I'm reading an implementation of im2col from a deep learning book (Chapter 7, CNN). Its purpose is to transform a 4-dimensional array into a 2-dimensional one. I don't understand why the implementation uses a 6-dimensional array, and I'd like to know the idea behind the algorithm the author used.

I've searched many papers for implementations of im2col, but none of them uses a high-dimensional array like this. The most useful material I've found so far for visualizing what im2col does is the picture in this paper: HAL Id: inria-00112631


import numpy as np

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Parameters
    ----------
    input_data : array of shape (batch size, channel, height, width), written (N, C, H, W) below
    filter_h : kernel height
    filter_w : kernel width
    stride : size of stride
    pad : size of padding
    Returns
    -------
    col : two dimensional array
    """
    N, C, H, W = input_data.shape
    # Number of window positions along each spatial axis.
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    # Zero-pad only the two spatial axes.
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    # For each offset (y, x) inside the window, copy that pixel of every window at once.
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    # Reorder to (N, out_h, out_w, C, filter_h, filter_w), then flatten to 2-D.
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col
Rainning asked May 11 '18 12:05



2 Answers

Let's try to visualize what im2col does. Its input is a stack of color images with dimensions (image id, color channel, vertical position, horizontal position). For simplicity, let's assume we have only one image:

[Figure: the single input image]

The first thing it does is pad the image:

[Figure: the image after padding]
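As a small illustration of the padding step, the sketch below applies the same np.pad call to a made-up 1x1x2x2 input with pad=1. The input values are arbitrary and chosen only so the zero border is easy to see.

```python
import numpy as np

# Hypothetical tiny input: 1 image, 1 channel, 2x2 pixels.
x = np.arange(1, 5, dtype=float).reshape(1, 1, 2, 2)
pad = 1

# Same call as in im2col: pad only the two spatial axes with zeros,
# leaving the image and channel axes untouched.
img = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')

print(img.shape)   # (1, 1, 4, 4)
print(img[0, 0])   # the original 2x2 block surrounded by a border of zeros
```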

Next, it cuts the padded image up into windows. The size of the windows is controlled by filter_h and filter_w, and the overlap between neighbouring windows by stride.

[Figure: the padded image cut into windows]
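To make the effect of window size and stride concrete, here is a minimal sketch of the out_h/out_w formula from the question applied to one axis. The helper name conv_output_size is an assumption for illustration, not something from the original code.

```python
def conv_output_size(size, filter_size, stride=1, pad=0):
    """Number of window positions along one axis (hypothetical helper)."""
    return (size + 2 * pad - filter_size) // stride + 1

# A 7-pixel axis with a 3-pixel window:
print(conv_output_size(7, 3, stride=1))         # 5 positions, neighbours overlap by 2
print(conv_output_size(7, 3, stride=2))         # 3 positions, neighbours overlap by 1
print(conv_output_size(7, 3, stride=3))         # 2 positions, no overlap, last pixel unused
print(conv_output_size(7, 3, stride=1, pad=1))  # 7 positions, same as the input size
```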

This is where the six dimensions come from: image id (missing in the example because we have only one image), grid height and width, color channel, and window height and width.

[Figure: the windows arranged along these dimensions]

The algorithm as it stands is a bit clumsy: it assembles the output with the dimensions in the wrong order and then has to correct that using transpose.
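The reordering can be seen from shapes alone. The sketch below uses arbitrary example sizes and an empty buffer; it only shows which axes the transpose from the question moves and what the final reshape merges.

```python
import numpy as np

# Arbitrary example sizes.
N, C, filter_h, filter_w, out_h, out_w = 2, 3, 5, 5, 4, 4

# Axis order in which the original im2col fills its buffer.
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

# Move the grid axes (4, 5) next to the image axis and the
# channel/window axes (1, 2, 3) to the end.
reordered = col.transpose(0, 4, 5, 1, 2, 3)
print(reordered.shape)   # (2, 4, 4, 3, 5, 5)

# Merge (N, out_h, out_w) into rows and (C, filter_h, filter_w) into columns.
flat = reordered.reshape(N * out_h * out_w, -1)
print(flat.shape)        # (32, 75)
```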

Better to get it right in the first place:

def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
    col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
    for y in range(out_h):
        for x in range(out_w):
            col[:, y, x] = img[
                ..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]
    # Merge axes 0-2 into the rows and axes 3-5 into the columns.
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

As a side note, we can do even better by using stride_tricks to avoid the nested for loop:

def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    NN, CC, HH, WW = img.strides
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
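    # View the padded image as overlapping windows without copying;
    # astype(float) then materializes the view as a new array.
    col = np.lib.stride_tricks.as_strided(
        img,
        (N, out_h, out_w, C, filter_h, filter_w),
        (NN, stride * HH, stride * WW, CC, HH, WW)).astype(float)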
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))
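For readers who would rather not compute strides by hand, a similar result can probably be had with np.lib.stride_tricks.sliding_window_view, which is available in NumPy 1.20 and later. This is a sketch under that assumption, not part of the original answer; the names im2col_windows and im2col_loop are hypothetical, and the loop version exists only as a reference to compare against.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def im2col_windows(input_data, filter_h, filter_w, stride=1, pad=0):
    """Hypothetical im2col variant built on sliding_window_view."""
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    # Shape (N, C, H - filter_h + 1, W - filter_w + 1, filter_h, filter_w):
    # every window position at stride 1, as a read-only view.
    windows = sliding_window_view(img, (filter_h, filter_w), axis=(2, 3))
    # Keep every stride-th position along both grid axes.
    windows = windows[:, :, ::stride, ::stride]
    N, C, out_h, out_w = windows.shape[:4]
    # Reorder to (N, out_h, out_w, C, filter_h, filter_w); reshape then copies.
    return windows.transpose(0, 2, 3, 1, 4, 5).reshape(N * out_h * out_w, -1)

def im2col_loop(input_data, filter_h, filter_w, stride=1, pad=0):
    """Plain loop over every window, used only as a reference."""
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    out_h = (H - filter_h) // stride + 1
    out_w = (W - filter_w) // stride + 1
    rows = []
    for n in range(N):
        for y in range(out_h):
            for x in range(out_w):
                ys, xs = y * stride, x * stride
                rows.append(img[n, :, ys:ys + filter_h, xs:xs + filter_w].ravel())
    return np.array(rows)

x = np.random.default_rng(0).random((2, 3, 7, 7))
a = im2col_windows(x, 3, 3, stride=2, pad=1)
b = im2col_loop(x, 3, 3, stride=2, pad=1)
print(a.shape, np.array_equal(a, b))   # (32, 27) True
```

The slicing step discards the window positions that the stride skips, so for large strides this builds a larger view than necessary, but no pixel data is copied until the final reshape.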

The last thing the algorithm does is reshape: it merges the first three dimensions (again only two in our example, because there is only one image). The red arrows show how the individual windows are lined up along the first new dimension:

[Figure: red arrows showing the order in which the windows are lined up]
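The order in which windows become rows is ordinary row-major flattening over (image id, grid row, grid column). A minimal sketch with arbitrary sizes, using NumPy's index helpers:

```python
import numpy as np

N, out_h, out_w = 2, 3, 4   # arbitrary example sizes

# Window (n, i, j) -> row index after merging the first three axes.
n, i, j = 1, 2, 3
row = np.ravel_multi_index((n, i, j), (N, out_h, out_w))
print(int(row))                       # 23
print((n * out_h + i) * out_w + j)    # 23, the same formula written out

# And back: row index -> (n, i, j).
back = tuple(int(k) for k in np.unravel_index(17, (N, out_h, out_w)))
print(back)                           # (1, 1, 1)
```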

The last three dimensions (color channel, y coordinate in window, x coordinate in window) are merged into the second output dimension. Individual pixels are lined up as indicated by the yellow arrows:

[Figure: yellow arrows showing the order in which the pixels are lined up]
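The same flattening applies within a window. In the sketch below, each value of a made-up window encodes its own (channel, y, x) position as c*100 + y*10 + x, so the order of the flattened pixels can be read off directly; the sizes are arbitrary.

```python
import numpy as np

C, filter_h, filter_w = 3, 2, 2   # arbitrary example sizes

# One window whose values encode (channel, y, x) as c*100 + y*10 + x.
c, y, x = np.meshgrid(np.arange(C), np.arange(filter_h), np.arange(filter_w),
                      indexing='ij')
window = c * 100 + y * 10 + x     # shape (C, filter_h, filter_w)

# What the final reshape does to each window: channel varies slowest, x fastest.
row = window.reshape(-1)
print(row.tolist())   # [0, 1, 10, 11, 100, 101, 110, 111, 200, 201, 210, 211]

# Column k holds channel c, window row y, window column x:
k = 6
pos = tuple(int(v) for v in np.unravel_index(k, (C, filter_h, filter_w)))
print(pos)            # (1, 1, 0)
```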

Paul Panzer answered Sep 30 '22 12:09



It looks like this function is just rearranging each of the C colour-channels in each of the N images into a (out_h x out_w) grid of overlapping image patches of size (filter_h x filter_w), and then flattening that into a 2d array where each row is a vector of pixels in an image patch.

The dimensions of the 6-D col (before being transposed and reshaped) are:

[sample, channel, y_position_within_patch, x_position_within_patch, y_patch_index, x_patch_index]

So, for example, col[n, c, :, :, i, j] will be a 2-D image patch (the i-th patch from the top and the j-th from the left in the grid of image patches).

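After the transpose and reshape, row (n*out_h + i)*out_w + j of col refers to this same window, with the pixels of all C channels flattened into one vector. Within that row, channel c occupies columns c*filter_h*filter_w through (c+1)*filter_h*filter_w - 1.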
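The relationship between the 6-D and 2-D indices can be checked numerically. The sketch below repeats the loop from the question inside a hypothetical helper, im2col_both, that also returns the 6-D array; the input sizes and the chosen (n, c, i, j) are arbitrary, and the last check assumes stride 1 and no padding.

```python
import numpy as np

def im2col_both(input_data, filter_h, filter_w, stride=1, pad=0):
    """The im2col from the question, also returning the 6-D array (for inspection)."""
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col6 = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col6[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
    col2 = col6.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col6, col2

data = np.arange(2*3*5*5, dtype=float).reshape(2, 3, 5, 5)
filter_h = filter_w = 3
col6, col2 = im2col_both(data, filter_h, filter_w)
out_h, out_w = col6.shape[4:]

n, c, i, j = 1, 2, 0, 1
patch = col6[n, c, :, :, i, j]                 # 2-D patch of one channel
row = (n*out_h + i)*out_w + j                  # row of that window in the 2-D result
cols = slice(c*filter_h*filter_w, (c + 1)*filter_h*filter_w)

print(np.array_equal(col2[row, cols], patch.ravel()))                     # True
print(np.array_equal(patch, data[n, c, i:i+filter_h, j:j+filter_w]))      # True
```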

Peter answered Sep 30 '22 12:09
