Efficiently Standardizing Images in a Numpy Array

I have a numpy array of images of shape (N, H, W, C) where N is the number of images, H the image height, W the image width and C the RGB channels.

I would like to standardize my images channel-wise: for each image, subtract each channel's mean and divide by that channel's standard deviation.

I did this in a loop, which works, but it is very inefficient, and because it makes a copy, I am running out of RAM.

import numpy as np

def standardize(img):
    # Scale a single 2-D channel to zero mean and unit standard deviation
    mean = np.mean(img)
    std = np.std(img)
    return (img - mean) / std

standardized_images = []
for img in rgb_images:
    r_channel = standardize(img[:,:,0])
    g_channel = standardize(img[:,:,1])
    b_channel = standardize(img[:,:,2])
    normalized_image = np.stack([r_channel, g_channel, b_channel], axis=-1)
    standardized_images.append(normalized_image)
standardized_images = np.array(standardized_images)

How can I do this more efficiently making use of numpy's capabilities?

Chris asked May 08 '18 17:05



1 Answer

Perform the reductions (mean, std) along the second and third axes, i.e. height and width, and pass keepdims=True so that the results keep shape (N, 1, 1, C) and broadcast against the original array in the subtraction and division steps:

mean = np.mean(rgb_images, axis=(1,2), keepdims=True)
std = np.std(rgb_images, axis=(1,2), keepdims=True)
standardized_images_out = (rgb_images - mean) / std
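
As a quick sanity check, the sketch below compares the broadcasting version against a per-image, per-channel loop like the one in the question. The array shape and the random seed are arbitrary assumptions chosen only for illustration.

```python
import numpy as np

# Hypothetical small batch: 4 images, 8x6 pixels, 3 channels.
rng = np.random.default_rng(0)
rgb_images = rng.random((4, 8, 6, 3))

# Vectorized: reduce over height and width (axes 1 and 2) only.
mean = np.mean(rgb_images, axis=(1, 2), keepdims=True)  # shape (4, 1, 1, 3)
std = np.std(rgb_images, axis=(1, 2), keepdims=True)    # shape (4, 1, 1, 3)
vectorized = (rgb_images - mean) / std

# Reference: standardize every channel of every image separately.
looped = np.empty_like(rgb_images)
for i, img in enumerate(rgb_images):
    for c in range(img.shape[-1]):
        channel = img[:, :, c]
        looped[i, :, :, c] = (channel - channel.mean()) / channel.std()

# True when both approaches agree up to floating-point tolerance.
results_match = np.allclose(vectorized, looped)
```

Because the reduced axes are kept as length-1 dimensions, each (image, channel) pair gets its own mean and standard deviation without any Python-level loop.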

Boost the performance further by re-using the mean values when computing the standard deviation, following its definition (inspired by this solution), like so:

std = np.sqrt(((rgb_images - mean)**2).mean((1,2), keepdims=True))
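
For reference, the expression above is the population standard deviation, which is also what np.std returns with its default ddof=0. Written out for image n and channel c, with x denoting the pixel values (the symbol names here are chosen for illustration and do not appear in the original answer):

```latex
\mu_{n,c} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,h,w,c},
\qquad
\sigma_{n,c} = \sqrt{\frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{n,h,w,c} - \mu_{n,c} \right)^{2}}
```

Since the mean is already available, computing the standard deviation this way avoids having np.std calculate the mean a second time.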

Packaging this into a function that takes the reduction axes as a parameter, we would have:

from __future__ import division  # only needed on Python 2
import numpy as np

def normalize_meanstd(a, axis=None):
    # axis param denotes axes along which mean & std reductions are to be performed
    mean = np.mean(a, axis=axis, keepdims=True)
    std = np.sqrt(((a - mean)**2).mean(axis=axis, keepdims=True))
    return (a - mean) / std

standardized_images = normalize_meanstd(rgb_images, axis=(1,2))
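
Since the question also mentions running out of RAM, here is a minimal sketch of a memory-conscious variant. It is not part of the original answer: the function name standardize_inplace and its chunk_size parameter are hypothetical. It assumes the input is already a floating-point array, overwrites it in place, and works on slices of the batch so that temporaries are only as large as one chunk.

```python
import numpy as np

def standardize_inplace(images, chunk_size=64):
    """Standardize an (N, H, W, C) float array per image and channel, in place.

    Hypothetical helper for illustration; overwrites `images` and returns it.
    """
    if not np.issubdtype(images.dtype, np.floating):
        # In-place float arithmetic on integer arrays (e.g. uint8) is not allowed.
        raise TypeError("images must be a floating-point array")
    for start in range(0, len(images), chunk_size):
        block = images[start:start + chunk_size]  # basic slice: a view, not a copy
        block -= block.mean(axis=(1, 2), keepdims=True)
        # block is now centered, so its mean square is the variance.
        block /= np.sqrt(np.square(block).mean(axis=(1, 2), keepdims=True))
    return images

# Usage on a small random batch (shape and seed are arbitrary).
rng = np.random.default_rng(0)
images = rng.random((10, 5, 4, 3))
expected = (images - images.mean(axis=(1, 2), keepdims=True)) / images.std(axis=(1, 2), keepdims=True)
out = standardize_inplace(images, chunk_size=3)
```

The trade-off is that the original pixel values are lost, and a chunk_size that is too small reintroduces some Python loop overhead.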
Divakar answered Oct 06 '22 00:10
