How to optimize this MaxPool2d implementation

I wrote an implementation of MaxPool2d (it runs correctly; I verified it against PyTorch). When testing it on the MNIST dataset, this function (updateOutput) takes a very long time to complete. How can I optimize this code using NumPy?

class MaxPool2d(Module):
    def __init__(self, kernel_size):
        super(MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.gradInput = None

    def updateOutput(self, input):
        kernel = self.kernel_size
        poolH = input.shape[2] // kernel
        poolW = input.shape[3] // kernel
        self.output = np.zeros((input.shape[0], input.shape[1], poolH, poolW))
        # for each pooled cell, remember the (row, col) of its max in the input
        self.index = np.zeros((input.shape[0], input.shape[1], poolH, poolW, 2),
                              dtype='int32')

        for i in range(input.shape[0]):
            for j in range(input.shape[1]):
                for k in range(0, input.shape[2] - kernel + 1, kernel):
                    for m in range(0, input.shape[3] - kernel + 1, kernel):
                        M = input[i, j, k : k + kernel, m : m + kernel]
                        self.output[i, j, k // kernel, m // kernel] = M.max()
                        self.index[i, j, k // kernel, m // kernel] = (
                            np.array(np.unravel_index(M.argmax(), M.shape))
                            + np.array((k, m)))

        return self.output

input shape = (batch_size, n_input_channels, h, w)

output shape = (batch_size, n_output_channels, h // kern_size, w // kern_size)
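For reference, the pooling loop can be exercised standalone to confirm these shapes; this is just my own sanity-check sketch (the Module base class and the index bookkeeping are omitted):

```python
import numpy as np

def max_pool2d_loop(x, kernel):
    """Naive loop version of updateOutput, values only."""
    b, c, h, w = x.shape
    out = np.zeros((b, c, h // kernel, w // kernel))
    for i in range(b):
        for j in range(c):
            for k in range(0, h - kernel + 1, kernel):
                for m in range(0, w - kernel + 1, kernel):
                    out[i, j, k // kernel, m // kernel] = \
                        x[i, j, k:k + kernel, m:m + kernel].max()
    return out

x = np.random.rand(2, 3, 6, 6)
print(max_pool2d_loop(x, 2).shape)  # (2, 3, 3, 3)
```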

asked Apr 14 '26 by annaFerdsf
1 Answer

For clarity, I've simplified your example by removing the batch-size and channel dimensions. Most of the time is spent calculating M.max(). I've created the benchmark function update_output_b, which runs the same loop but takes the max of a constant array of ones.

import time
import numpy as np

def timeit(cycles):
    def timed(func):
        def wrapper(*args, **kwargs):
            start_t = time.time()
            for _ in range(cycles):
                func(*args, **kwargs)
            t = (time.time() - start_t) / cycles
            print(f'{func.__name__} mean execution time: {t:.3f}s')

        return wrapper
    return timed

@timeit(100)
def update_output_b(input, kernel):
    ones = np.ones((kernel, kernel))

    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    output = np.zeros((pool_h, pool_w))

    for i in range(0, input.shape[0] - kernel + 1, kernel):
        for j in range(0, input.shape[1] - kernel + 1, kernel):
            output[i // kernel, j // kernel] = ones.max()

    return output

in_arr = np.random.rand(3001, 200)
update_output_b(in_arr, 3)

Its output is update_output_b mean execution time: 0.277s, because it doesn't use NumPy's fully vectorized operations. When possible, you should always prefer native NumPy functions over Python loops.
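As a small illustration of that point (my own example, not part of the answer), compare a pure-Python loop with the equivalent vectorized call:

```python
import numpy as np

a = np.random.rand(1000, 1000)

# loop version: every element goes through the Python interpreter
def row_maxes_loop(arr):
    return [max(row) for row in arr]

# vectorized version: one call into compiled NumPy code
def row_maxes_np(arr):
    return arr.max(axis=1)

# both produce the same result; the vectorized one is far faster
assert np.allclose(row_maxes_loop(a), row_maxes_np(a))
```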

In addition, using slices of the input array slows execution, since access to contiguous memory is in most cases faster.
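The contiguity point can be checked directly: a kernel-sized slice is a strided view whose rows are not adjacent in memory (a hypothetical 200-column array for illustration):

```python
import numpy as np

a = np.random.rand(200, 200)
M = a[3:6, 9:12]                 # 3x3 window, as in the pooling loop

print(a.flags['C_CONTIGUOUS'])   # True
print(M.flags['C_CONTIGUOUS'])   # False: each row of M is 200 floats apart
print(M.strides)                 # (1600, 8), vs (24, 8) for a contiguous 3x3
```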

@timeit(100)
def update_output_1(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    output = np.zeros((pool_h, pool_w))

    for i in range(0, input.shape[0] - kernel + 1, kernel):
        for j in range(0, input.shape[1] - kernel + 1, kernel):
            M = input[i : i + kernel, j : j + kernel]
            output[i // kernel, j // kernel] = M.max()

    return output

update_output_1(in_arr, 3)

The code returns update_output_1 mean execution time: 0.332s (+55 ms compared to the previous one).

I've added vectorized code below. It works ~20x faster (update_output_2 mean execution time: 0.015s), but it is probably still far from optimal.

@timeit(100)
def update_output_2(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    input_h = pool_h * kernel
    input_w = pool_w * kernel

    # crop input
    output = input[:input_h, :input_w]
    # calculate max along second axis
    output = output.reshape((-1, kernel))
    output = output.max(axis=1)
    # calculate max along first axis
    output = output.reshape((pool_h, kernel, pool_w))
    output = output.max(axis=1)

    return output

update_output_2(in_arr, 3)

It generates the output in 3 steps:

  • Cropping the input to a size divisible by the kernel
  • Calculating the max along the second axis (this reduces the offsets between slices along the first axis)
  • Calculating the max along the first axis
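The same reshape trick extends to the question's 4-D (batch, channels, h, w) layout; here is a sketch of my own (not part of the original answer), exposing the two kernel dimensions and reducing over both at once:

```python
import numpy as np

def update_output_4d(x, kernel):
    b, c, h, w = x.shape
    pool_h, pool_w = h // kernel, w // kernel
    # crop to a size divisible by the kernel, as in update_output_2
    x = x[:, :, :pool_h * kernel, :pool_w * kernel]
    # expose the kernel dims: (b, c, pool_h, kernel, pool_w, kernel)
    x = x.reshape(b, c, pool_h, kernel, pool_w, kernel)
    # reduce over both kernel dims in one call
    return x.max(axis=(3, 5))

x = np.random.rand(4, 2, 7, 7)
print(update_output_4d(x, 2).shape)  # (4, 2, 3, 3)
```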

Edit:

I've added modifications for retrieving the indices of the max values. However, you should double-check the index arithmetic, as I've only tested it on a random array.

It calculates output_indices along the second axis in each window and then uses output_indices_selector to select the maximum along the first one.

def update_output_3(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    input_h = pool_h * kernel
    input_w = pool_w * kernel

    # crop input
    output = input[:input_h, :input_w]

    # calculate max along second axis
    output_tmp = output.reshape((-1, kernel))
    output_indices = output_tmp.argmax(axis=1)
    output_indices += np.arange(output_indices.shape[0]) * kernel
    output_indices = np.unravel_index(output_indices, output.shape)
    output_tmp = output[output_indices]

    # calculate max along first axis
    output_tmp = output_tmp.reshape((pool_h, kernel, pool_w))
    output_indices_selector = (kernel * pool_w * np.arange(pool_h).reshape(pool_h, 1))
    output_indices_selector = output_indices_selector.repeat(pool_w, axis=1)
    output_indices_selector += pool_w * output_tmp.argmax(axis=1)
    output_indices_selector += np.arange(pool_w)
    output_indices_selector = output_indices_selector.flatten()

    output_indices = (output_indices[0][output_indices_selector],
                      output_indices[1][output_indices_selector])
    output = output[output_indices].reshape(pool_h, pool_w)

    return output, output_indices
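The index bookkeeping above hinges on argmax over a reshaped view plus unravel_index; here is a minimal self-contained check of that pattern on a known array (my own example, not from the answer):

```python
import numpy as np

x = np.array([[1., 9., 2., 3.],
              [4., 5., 8., 0.]])
kernel = 2

# argmax within each length-2 group along the rows
flat = x.reshape(-1, kernel)             # shape (4, 2)
idx = flat.argmax(axis=1)                # position inside each group
idx += np.arange(idx.shape[0]) * kernel  # position in the flattened array
rows, cols = np.unravel_index(idx, x.shape)

print(x[rows, cols])  # [9. 3. 5. 8.] -- the max of each length-2 group
```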
answered Apr 15 '26 by Jakub Gąsiewski
