 

Fastest way to compute image dataset channel-wise mean and standard deviation in Python

I have a huge image dataset that does not fit in memory. I want to compute the mean and standard deviation, loading images from disk.

I'm currently trying to use this algorithm (Welford's online algorithm) found on Wikipedia:

# For a new value newValue, compute the new count, new mean and new M2.
# mean accumulates the mean of the entire dataset
# M2 aggregates the squared distance from the mean
# count aggregates the number of samples seen so far
def update(existingAggregate, newValue):
    (count, mean, M2) = existingAggregate
    count = count + 1
    delta = newValue - mean
    mean = mean + delta / count
    delta2 = newValue - mean
    M2 = M2 + delta * delta2
    # return the updated aggregate, not the one that was passed in
    return (count, mean, M2)

# Retrieve the mean and (sample) variance from an aggregate
def finalize(existingAggregate):
    (count, mean, M2) = existingAggregate
    # check first: M2 / (count - 1) would divide by zero when count == 1
    if count < 2:
        return float('nan')
    return (mean, M2 / (count - 1))
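In symbols, each call to update applies the recurrence below, and finalize returns the last line. The notation is ours: $x_k$ is the $k$-th value, $n_k$ is count, $\bar{x}_k$ is mean, $\delta_k$ is delta, $x_k - \bar{x}_k$ is delta2, and $M_{2,k}$ is M2.

$$
\begin{aligned}
n_k &= n_{k-1} + 1 \\
\delta_k &= x_k - \bar{x}_{k-1} \\
\bar{x}_k &= \bar{x}_{k-1} + \frac{\delta_k}{n_k} \\
M_{2,k} &= M_{2,k-1} + \delta_k \, (x_k - \bar{x}_k) \\
s^2 &= \frac{M_{2,n}}{n - 1}
\end{aligned}
$$

Because each step touches a single value, running it once per pixel from the Python interpreter is what makes it slow on a large dataset.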

This is my current implementation (computing just for the red channel):

import cv2
import numpy as np
from tqdm import tqdm

count = 0
mean = 0.0
M2 = 0.0
for file in tqdm(first):  # first: list of image file paths
    image = cv2.imread(file)  # OpenCV loads channels in BGR order
    height, width = image.shape[:2]
    for i in range(height):
        for j in range(width):
            b, g, r = image[i, j, :]
            newValue = float(r)  # avoid uint8 arithmetic
            count = count + 1
            delta = newValue - mean
            mean = mean + delta / count
            delta2 = newValue - mean
            M2 = M2 + delta * delta2

print('first mean', mean)
print('first std', np.sqrt(M2 / (count - 1)))

This implementation gives results that are close enough on a subset of the dataset I tried.

The problem is that it is extremely slow and therefore nonviable.

  • Is there a standard way of doing this?

  • How can I adapt this to get results faster, or compute the RGB mean and standard deviation for the whole dataset at a reasonable speed without loading it all into memory at once?

asked Dec 16 '17 21:12 by Bruno Klein


People also ask

How do you find the mean and standard deviation of an image data set in Python?

Mean: divide the sum of pixel values by the total count, i.e. the number of pixels in the dataset, computed as len(df) * image_size * image_size. Standard deviation: use total_std = sqrt(psum_sq / count - total_mean ** 2), where psum_sq is the sum of squared pixel values.
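The sum / sum-of-squares approach needs only three running accumulators, so it streams over the dataset one image at a time. A minimal sketch, assuming images arrive as H x W x C numpy arrays from any iterable; `dataset_mean_std` is a hypothetical name, while psum, psum_sq, count, total_mean and total_std follow the names above. It returns the population standard deviation (dividing by count).

```python
import numpy as np


def dataset_mean_std(images):
    """Per-channel mean and population std from running sums.

    images: any iterable of H x W x C arrays (e.g. a generator that loads
    one file at a time), so only one image is in memory at once.
    """
    count = 0
    psum = 0.0
    psum_sq = 0.0
    for image in images:
        # one row per pixel, one column per channel, in float64
        pixels = image.reshape(-1, image.shape[-1]).astype(np.float64)
        psum = psum + pixels.sum(axis=0)
        psum_sq = psum_sq + (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    total_mean = psum / count
    # clamp tiny negative values caused by floating-point rounding
    total_std = np.sqrt(np.maximum(psum_sq / count - total_mean ** 2, 0.0))
    return total_mean, total_std
```

For example, `dataset_mean_std(cv2.imread(f) for f in first)` reads one file at a time and returns arrays in BGR order. The formula E[x²] - mean² can lose precision when the mean is large relative to the spread; with 8-bit pixels and float64 accumulators that is usually not a concern, and Welford's update is the more robust alternative.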

Why standard deviation is used in image processing?

The standard deviation (σ) provides a measure of the dispersion of image gray-level intensities and can be understood as a measure of the power level of the alternating signal component acquired by the camera.

How do you find the mean in image processing?

Mean value is the sum of pixel values divided by the total number of pixel values. Pixel values: each of the pixels that represents an image stored inside a computer has a pixel value which describes how bright that pixel is and/or what color it should be.

Which Python method finds standard deviation?

The statistics.stdev() method calculates the standard deviation from a sample of data. Standard deviation is a measure of how spread out the numbers are: a large standard deviation indicates that the data is spread out, while a small standard deviation indicates that the data is clustered closely around the mean.
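The distinction between a sample and a population standard deviation matters when comparing results: the question's M2 / (count - 1) is the sample variance, while np.std divides by n unless told otherwise. A small illustration using the standard library and numpy:

```python
import statistics

import numpy as np

values = [2, 4, 4, 4, 5, 5, 7, 9]

sample_std = statistics.stdev(values)       # divides by n - 1
population_std = statistics.pstdev(values)  # divides by n

print(population_std)          # 2.0
print(np.std(values))          # numpy default ddof=0: same as pstdev
print(np.std(values, ddof=1))  # ddof=1: same as stdev
```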


2 Answers

Since this is a numerically heavy task (many iterations over a matrix or tensor), I always suggest using a library that is good at this: numpy.

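A properly installed numpy runs these operations in compiled, vectorized C loops that are optimized for operating on arrays of floating-point numbers. (Reductions such as np.mean and np.std use numpy's own loops; the underlying BLAS, Basic Linear Algebra Subroutines, is used for linear-algebra operations such as matrix products.)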

cv2.imread already returns a numpy array. Note that OpenCV stores the channels in BGR order, so the red channel is index 2. You can get the red channel as a flattened 1-D array with

import numpy as np
val = np.reshape(image[:, :, 2], -1)

its mean with

np.mean(val)

and its standard deviation with

np.std(val)  # population std (ddof=0); np.std(val, ddof=1) matches M2 / (count - 1)
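The same idea covers all three channels at once: reducing over the two spatial axes with axis=(0, 1) yields one mean and one standard deviation per channel, in the order the channels are stored (BGR for cv2.imread). A minimal sketch; `image_channel_stats` is a hypothetical helper name.

```python
import numpy as np


def image_channel_stats(image):
    """Per-channel mean and population std (ddof=0) of one H x W x C image."""
    img = image.astype(np.float64)
    return img.mean(axis=(0, 1)), img.std(axis=(0, 1))


# Tiny synthetic example: only channel 2 (red, if loaded with cv2.imread) is non-zero
image = np.zeros((2, 2, 3), dtype=np.uint8)
image[..., 2] = [[10, 20], [30, 40]]
mean, std = image_channel_stats(image)
print(mean, std)
```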

This way, you get rid of the two inner Python loops:

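import cv2
import numpy as np
from tqdm import tqdm

count = 0
mean = 0
M2 = 0
for file in tqdm(first):
    image = cv2.imread(file)
    val = np.reshape(image[:, :, 2], -1)  # red channel (OpenCV uses BGR order)
    img_mean = np.mean(val)
    img_std = np.std(val)
    ...  # merge img_mean and img_std into count, mean and M2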

The rest of the incremental update should be straightforward.
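A minimal sketch of that remaining step, assuming each image's statistics are merged into the running aggregate with Chan et al.'s parallel formula (the pairwise form of Welford's update, described in the same Wikipedia article). `merge` and `running_stats` are hypothetical names. Since np.std defaults to ddof=0, an image's M2 is its pixel count times img_std ** 2. The sketch reduces over the pixel axis, so it handles all channels at once.

```python
import numpy as np


def merge(aggregate, batch_count, batch_mean, batch_m2):
    """Combine a running (count, mean, M2) aggregate with one batch's
    statistics (Chan et al.'s parallel update). Works element-wise, so
    mean and M2 can be per-channel arrays."""
    count, mean, m2 = aggregate
    new_count = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * batch_count / new_count
    new_m2 = m2 + batch_m2 + delta ** 2 * count * batch_count / new_count
    return new_count, new_mean, new_m2


def running_stats(images):
    """Per-channel mean and sample std over an iterable of H x W x C images."""
    aggregate = (0, 0.0, 0.0)
    for image in images:
        val = image.reshape(-1, image.shape[-1]).astype(np.float64)
        n = val.shape[0]              # pixels in this image
        img_mean = val.mean(axis=0)
        img_std = val.std(axis=0)     # ddof=0, so M2 = n * img_std ** 2
        aggregate = merge(aggregate, n, img_mean, n * img_std ** 2)
    count, mean, m2 = aggregate
    return mean, np.sqrt(m2 / (count - 1))
```

Called as `running_stats(cv2.imread(f) for f in first)`, it holds one image in memory at a time and returns BGR-ordered arrays; the standard deviation matches the question's M2 / (count - 1).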

Once you have done this, the bottleneck becomes image loading speed, which is limited by disk read performance. In that regard, based on my prior experience, I suspect that using multiple threads, as others suggested, will help a lot.
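A minimal sketch of multi-threaded loading with the standard library's concurrent.futures.ThreadPoolExecutor. `threaded_channel_stats`, `max_workers` and `chunk_size` are hypothetical names and defaults. The speedup assumes the loader spends most of its time outside the GIL: file reads do, and it is an assumption here that cv2.imread releases the GIL while decoding. Paths are processed in chunks because Executor.map submits every task it is given up front, which could otherwise hold many decoded images in memory.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def threaded_channel_stats(paths, load, max_workers=8, chunk_size=64):
    """Per-channel mean and population std over images loaded in threads.

    paths: a list of file paths; load: function mapping a path to an
    H x W x C array (for example cv2.imread). Working chunk by chunk keeps
    roughly chunk_size decoded images in memory at a time.
    """
    count = 0
    psum = 0.0
    psum_sq = 0.0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for start in range(0, len(paths), chunk_size):
            chunk = paths[start:start + chunk_size]
            # pool.map runs load() concurrently and yields results in order
            for image in pool.map(load, chunk):
                pixels = image.reshape(-1, image.shape[-1]).astype(np.float64)
                psum = psum + pixels.sum(axis=0)
                psum_sq = psum_sq + (pixels ** 2).sum(axis=0)
                count += pixels.shape[0]
    mean = psum / count
    std = np.sqrt(np.maximum(psum_sq / count - mean ** 2, 0.0))
    return mean, std
```

For example, `threaded_channel_stats(first, cv2.imread)` returns BGR-ordered arrays; a good max_workers value depends on the disk and CPU and is worth measuring.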

answered Sep 18 '22 18:09 by Yo Hsiao


You can also use OpenCV's cv2.meanStdDev:

cv2.meanStdDev(src[, mean[, stddev[, mask]]]) → mean, stddev
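cv2.meanStdDev works on one image at a time, so its results still have to be combined across the dataset. One way, sketched below, is to turn each image's mean and stddev back into a sum and a sum of squares, using E[x²] = std² + mean². This assumes the returned stddev is the population value (divided by the pixel count) and that mean and stddev come back as per-channel arrays; `accumulate_meanstddev` and `per_image` are hypothetical names.

```python
import numpy as np


def accumulate_meanstddev(results):
    """Combine per-image (mean, stddev, n_pixels) triples into a
    dataset-wide per-channel mean and population std."""
    count = 0
    psum = 0.0
    psum_sq = 0.0
    for mean, stddev, n in results:
        mean = np.asarray(mean, dtype=np.float64).ravel()
        stddev = np.asarray(stddev, dtype=np.float64).ravel()
        psum = psum + mean * n                             # sum of pixel values
        psum_sq = psum_sq + (stddev ** 2 + mean ** 2) * n  # sum of squares
        count += n
    total_mean = psum / count
    total_std = np.sqrt(np.maximum(psum_sq / count - total_mean ** 2, 0.0))
    return total_mean, total_std


# Usage sketch (requires OpenCV; `first` is the list of paths from the question):
# import cv2
# def per_image(paths):
#     for f in paths:
#         img = cv2.imread(f)
#         mean, stddev = cv2.meanStdDev(img)
#         yield mean, stddev, img.shape[0] * img.shape[1]
# total_mean, total_std = accumulate_meanstddev(per_image(first))  # BGR order
```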
answered Sep 20 '22 18:09 by Andrey Smorodov