Can someone help me improve this code? It implements an adaptive median filter, and it is very slow on large images.
import numpy as np
def padding(img, pad):
    """Return a float64 copy of ``img`` surrounded by ``pad`` rows/columns of zeros.

    Parameters
    ----------
    img : 2-D array
        Image to pad.
    pad : int
        Number of zero rows/columns added on every side (0 is allowed).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows + 2*pad, cols + 2*pad)`` with ``img`` in the centre.
    """
    rows, cols = img.shape[0], img.shape[1]
    padded_img = np.zeros((rows + 2 * pad, cols + 2 * pad))
    # Use explicit stop indices: a negative stop such as ``-pad`` collapses to
    # an empty slice when pad == 0 and the assignment would fail.
    padded_img[pad:pad + rows, pad:pad + cols] = img
    return padded_img
def AdaptiveMedianFilter(img, s=3, sMax=7):
    """Apply an adaptive median filter to a single-channel image.

    Every output pixel is produced by ``Lvl_A``. It starts with an ``s x s``
    window and grows it by 2 until the window median lies strictly between the
    window minimum and maximum, or until the size would exceed ``sMax``.

    Parameters
    ----------
    img : 2-D array
        Input image; a 3-D (multi-channel) array is rejected.
    s : int
        Initial window size. Windows span ``s // 2`` pixels on each side of
        the centre, so an even ``s`` behaves like ``s + 1``.
    sMax : int
        Largest window size allowed.

    Returns
    -------
    numpy.ndarray
        Float64 array with the same height and width as ``img``.

    Raises
    ------
    Exception
        If ``img`` is 3-D.
    ValueError
        If ``s > sMax``; the initial window would not fit inside the padding.
    """
    img = np.asarray(img)
    if img.ndim == 3:
        raise Exception("Single channel image only")
    if s > sMax:
        raise ValueError("initial window size s must not exceed sMax")
    H, W = img.shape
    a = sMax // 2
    # Pad by the largest half-window so every window up to sMax stays in bounds.
    padded_img = padding(img, a)
    f_img = np.zeros((H, W))
    # Visit exactly the H x W original pixels (the old bound H+a+1 computed an
    # extra row/column that was then discarded).
    for i in range(H):
        for j in range(W):
            # (i + a, j + a) is pixel (i, j) expressed in padded coordinates.
            f_img[i, j] = Lvl_A(padded_img, i + a, j + a, s, sMax)
    return f_img
 def Lvl_A(mat,x,y,s,sMax):
    window = mat[x-(s//2):x+(s//2)+1,y-(s//2):y+(s//2)+1]
    Zmin = np.min(window)
    Zmed = np.median(window)
    Zmax = np.max(window)
    A1 = Zmed - Zmin
    A2 = Zmed - Zmax
    if A1 > 0 and A2 < 0:
        return Lvl_B(window)
    else:
        s += 2 
        if s <= sMax:
            return Lvl_A(mat,x,y,s,sMax)
        else:
             return Zmed
def Lvl_B(window):
    """Stage B of the adaptive median filter.

    Keep the centre pixel if it lies strictly between the window minimum and
    maximum (i.e. it is not impulse noise); otherwise return the window median.

    Parameters
    ----------
    window : 2-D array
        Window whose centre element is ``window[h // 2, w // 2]``.
    """
    Zmin = np.min(window)
    Zmed = np.median(window)
    Zmax = np.max(window)
    Zxy = window[window.shape[0] // 2, window.shape[1] // 2]
    # Compare directly: with unsigned integer windows ``Zxy - Zmax`` wraps
    # around to a large positive value instead of going negative, which made
    # the old sign test always replace the centre pixel.
    if Zmin < Zxy < Zmax:
        return Zxy
    return Zmed
Is there any way to speed this code up, for example with a vectorized sliding window? I don't know which NumPy functions to use. P.S.: The image is padded, so the code never has to check for out-of-bounds access.
Numba's `njit` is well suited to this kind of computation. Combined with `parallel=True` and `prange`, it can be much faster. Moreover, as @CrisLuengo pointed out, you can pass the minimum, maximum and median values to `Lvl_B` rather than recomputing them.
Here is the modified code:
import numpy as np
from numba import njit,prange
@njit
def padding(img, pad):
    """Return a float64 copy of ``img`` surrounded by ``pad`` rows/columns of zeros.

    ``pad`` may be 0. Plain slicing is used rather than ``np.pad`` so the body
    stays inside what this file compiles with Numba.
    """
    rows, cols = img.shape[0], img.shape[1]
    padded_img = np.zeros((rows + 2 * pad, cols + 2 * pad))
    # Explicit stop indices: ``-pad`` would give an empty slice when pad == 0.
    padded_img[pad:pad + rows, pad:pad + cols] = img
    return padded_img
@njit(parallel=True)
def AdaptiveMedianFilter(img, s=3, sMax=7):
    """Apply an adaptive median filter to a single-channel image (Numba, parallel).

    Each output pixel is produced by ``Lvl_A``. Rows are distributed across
    threads with ``prange``, and each iteration writes only its own row of the
    output.

    Parameters
    ----------
    img : 2-D array
        Input image. NOTE(review): under Numba the ``H, W = img.shape`` unpack
        is typed at compile time, so a 3-D input may fail during compilation
        before the check below runs; confirm.
    s : int
        Initial window size.
    sMax : int
        Largest window size allowed.

    Returns
    -------
    numpy.ndarray
        Float64 array with the same height and width as ``img``.
    """
    if len(img.shape) == 3:
        raise Exception("Single channel image only")
    if s > sMax:
        # The initial window would not fit inside the padding.
        raise ValueError("initial window size s must not exceed sMax")
    H, W = img.shape
    a = sMax // 2
    padded_img = padding(img, a)
    f_img = np.zeros((H, W))
    # Visit exactly the H x W original pixels (the old bound H+a+1 computed an
    # extra row/column that was then discarded, and a == 0 gave an empty result).
    for i in prange(H):
        for j in range(W):
            # (i + a, j + a) is pixel (i, j) expressed in padded coordinates.
            f_img[i, j] = Lvl_A(padded_img, i + a, j + a, s, sMax)
    return f_img
@njit
def Lvl_A(mat, x, y, s, sMax):
    """Stage A of the adaptive median filter (Numba-compiled).

    Grows a square window centred on ``mat[x, y]`` from size ``s`` in steps of 2.
    Once the window median lies strictly between the window minimum and maximum,
    the decision is handed to ``Lvl_B`` along with the statistics already
    computed. If the next size would exceed ``sMax``, the median of the last
    window examined is returned.

    ``x``/``y`` index ``mat`` directly; the caller pads ``mat`` so that windows
    up to ``sMax`` stay in bounds.
    """
    size = s
    while True:
        r = size // 2
        win = mat[x - r:x + r + 1, y - r:y + r + 1]
        lo = np.min(win)
        med = np.median(win)
        hi = np.max(win)
        if lo < med < hi:
            return Lvl_B(win, lo, med, hi)
        size += 2
        if size > sMax:
            # No room left to grow: fall back to the current median.
            return med
@njit
def Lvl_B(window, Zmin, Zmed, Zmax):
    """Stage B of the adaptive median filter (Numba-compiled).

    Keep the centre pixel if it lies strictly between ``Zmin`` and ``Zmax``
    (the window's precomputed minimum and maximum); otherwise return ``Zmed``,
    the precomputed window median.
    """
    Zxy = window[window.shape[0] // 2, window.shape[1] // 2]
    # Direct comparisons instead of the sign of ``Zxy - Zmin`` / ``Zxy - Zmax``:
    # for unsigned integer inputs that subtraction may wrap around instead of
    # going negative. Results are unchanged for floating-point windows.
    if Zmin < Zxy < Zmax:
        return Zxy
    return Zmed
This code is about 500 times faster on my machine for a 256×256 random image.
Note that the first call will not be much faster, because it includes the JIT compilation time.
Note also that the computation could be made even faster by not recomputing the min/max/median for every pixel, since neighbouring sliding windows share most of their values (see the paper "Median Filtering in Constant Time", Perreault et al., 2007).
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With