
Finding mean and standard deviation across image channels in PyTorch

Say I have a batch of images in the form of tensors with dimensions (B x C x W x H) where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want a resulting tensor in the form 1 x C. Is there a straightforward way to do this?

I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1), but I get the error:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
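
This error typically appears when view() is called on a non-contiguous tensor. A minimal sketch that reproduces it on random data, assuming the tensor has been permuted (and is therefore non-contiguous); the shapes are illustrative only:

import torch

x = torch.randn(8, 3, 28, 28).permute(0, 2, 3, 1)  # non-contiguous after permute
# x.view(3, -1)       # raises the view/stride error quoted above
y = x.reshape(3, -1)  # reshape() copies when necessary, so it succeeds

Note that even reshape(C, -1) on a [B, C, W, H] batch does not group values by channel; the channel dimension has to be moved to the front first (e.g. tensor.permute(1, 0, 2, 3).reshape(C, -1)).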

Edit

After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to get the per-channel mean and standard deviation.

asked Feb 06 '20 by ch1maera


People also ask

How do you find the mean and standard deviation of an image in Python?

Mean: divide the sum of pixel values by the total pixel count, i.e. the number of pixels in the dataset, computed as len(df) * image_size * image_size. Standard deviation: use the equation total_std = sqrt(psum_sq / count - total_mean ** 2), where psum_sq is the running sum of squared pixel values.
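
A rough sketch of that running-sum approach on random data (the names psum, psum_sq and count mirror the snippet above; the loader and shapes are illustrative):

import torch
from torch.utils.data import TensorDataset, DataLoader

# Illustrative dataset: 64 random 3-channel 28x28 images
loader = DataLoader(TensorDataset(torch.randn(64, 3, 28, 28), torch.zeros(64, 1)), batch_size=8)

psum = torch.zeros(3)
psum_sq = torch.zeros(3)
count = 0
for images, _ in loader:
    psum += images.sum(dim=(0, 2, 3))            # per-channel sum of pixel values
    psum_sq += (images ** 2).sum(dim=(0, 2, 3))  # per-channel sum of squared values
    count += images.size(0) * images.size(2) * images.size(3)  # pixels per channel

total_mean = psum / count
total_std = torch.sqrt(psum_sq / count - total_mean ** 2)
print(total_mean, total_std)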

How do I flatten in PyTorch?

torch.flatten flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.
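
For instance, a small illustration of start_dim (shapes chosen arbitrarily):

import torch

x = torch.randn(8, 3, 28, 28)
print(torch.flatten(x).shape)               # torch.Size([18816]) -- everything flattened
print(torch.flatten(x, start_dim=2).shape)  # torch.Size([8, 3, 784]) -- batch and channel dims kept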

What is Torch stack?

torch.stack(tensors, dim=0, *, out=None) → Tensor. Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.
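
For instance (arbitrary shapes):

import torch

a = torch.randn(3, 28, 28)
b = torch.randn(3, 28, 28)
stacked = torch.stack([a, b], dim=0)  # adds a new leading dimension
print(stacked.shape)                  # torch.Size([2, 3, 28, 28])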


2 Answers

Note that variances add, not standard deviations. See detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters
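
A quick numerical check of that point (purely illustrative):

import torch

torch.manual_seed(0)
x = torch.randn(1_000_000) * 2.0  # std ~ 2, var ~ 4
y = torch.randn(1_000_000) * 3.0  # std ~ 3, var ~ 9

print((x + y).var())  # ~ 13 = 4 + 9: variances of independent variables add
print((x + y).std())  # ~ 3.6 = sqrt(13), not 2 + 3 = 5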

Here is the modified code:

import torch

# trainloader is assumed to be a DataLoader yielding (images, targets) batches
nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate per-image channel means and variances
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
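
These values can then be fed to the transform the question mentions, e.g. (assuming torchvision is installed):

from torchvision import transforms

normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())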
answered Sep 21 '22 by debadeepta


You just need to rearrange the batch tensor in the right way: from [B, C, W, H] to [B, C, W * H]:

batch = batch.view(batch.size(0), batch.size(1), -1)

Here is a complete usage example on random data:

Code:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

Output:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
answered Sep 19 '22 by trsvchn