Say I have a batch of images in the form of tensors with dimensions (B x C x W x H), where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want a resulting tensor in the form 1 x C. Is there a straightforward way to do this?
I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1), but I get the error:
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Edit
After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to get the per-channel mean and standard deviation.
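For readers hitting the same error, here is a minimal sketch of what goes wrong, using made-up shapes (B=4, C=3, W=H=5 are assumptions for illustration only). Calling view(C, -1) directly on a contiguous (B, C, W, H) tensor would run without error but would mix values from different channels into each row; moving the channel axis to the front first gives the right grouping, but makes the tensor non-contiguous, which is when view() raises the error quoted above and reshape() is needed.

```python
import torch

B, C, W, H = 4, 3, 5, 5              # hypothetical sizes for illustration
x = torch.randn(B, C, W, H)

# Move the channel axis to the front: shape (C, B, W, H).
# The result is a non-contiguous view of the same memory.
xc = x.permute(1, 0, 2, 3)

# view() cannot merge the B, W and H axes here without copying, so it raises.
try:
    xc.view(C, -1)
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape() copies when it has to, so it works on the permuted tensor.
flat = xc.reshape(C, -1)             # shape (C, B * W * H)
mean = flat.mean(dim=1)              # one value per channel
std = flat.std(dim=1)
```

The same per-channel values can also be obtained without any reshaping by reducing over the batch and spatial axes, for example x.mean(dim=(0, 2, 3)).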
mean: simply divide the sum of pixel values by the total count, where the count is the number of pixels in the dataset, computed as len(df) * image_size * image_size.
standard deviation: use the following equation, where psum_sq is the sum of squared pixel values: total_std = sqrt(psum_sq / count - total_mean ** 2)
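The sum and sum-of-squares recipe above can be sketched per channel as follows. This is only an illustration on random data: the dataset, the batch size, and the names psum and count are assumptions (psum_sq and total_mean follow the formula above), and float64 accumulators are used to limit rounding error in the subtraction.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
# Hypothetical dataset: 64 random 3-channel 28x28 images
data = torch.randn(64, 3, 28, 28) * 2.0 + 0.5
loader = DataLoader(TensorDataset(data), batch_size=8)

psum = torch.zeros(3, dtype=torch.float64)      # per-channel sum of pixels
psum_sq = torch.zeros(3, dtype=torch.float64)   # per-channel sum of squared pixels
count = 0                                       # pixels per channel seen so far

for (batch,) in loader:
    batch = batch.double()
    psum += batch.sum(dim=(0, 2, 3))
    psum_sq += (batch ** 2).sum(dim=(0, 2, 3))
    count += batch.size(0) * batch.size(2) * batch.size(3)

total_mean = psum / count
# Population (biased) standard deviation: sqrt(E[x^2] - E[x]^2)
total_std = torch.sqrt(psum_sq / count - total_mean ** 2)
```

Unlike averaging per-image statistics, this accumulates over every pixel, so the result does not depend on how the data is split into batches.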
torch.flatten: Flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.
torch.stack(tensors, dim=0, *, out=None) → Tensor: Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.
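As a small sketch of how these two functions relate to the problem at hand (the image sizes below are assumptions for illustration): torch.stack can build a [B, C, W, H] batch from individual images, and torch.flatten with start_dim=2 collapses the spatial axes into one.

```python
import torch

# Four hypothetical 3-channel 28x28 images
images = [torch.randn(3, 28, 28) for _ in range(4)]

# stack() adds a new leading dimension: 4 tensors of (3, 28, 28) -> (4, 3, 28, 28)
batch = torch.stack(images, dim=0)

# Flatten only the spatial dimensions: (4, 3, 28, 28) -> (4, 3, 784)
flat = torch.flatten(batch, start_dim=2)

# With no arguments, everything collapses into one dimension: (9408,)
all_flat = torch.flatten(batch)
```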
Note that variances add, not standard deviations. See detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters
Here is the modified code:
import torch

# trainloader is assumed to be an existing DataLoader yielding (images, targets)
nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate per-image, per-channel means and variances
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

# Average over all images; take the square root only at the end
mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
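To illustrate the note above about averaging variances rather than standard deviations, here is a tiny sketch with two made-up groups of values. The groups are chosen to have equal size and equal mean, which is the case where averaging the variances reproduces the overall variance exactly; with differing group means, the spread of the means would contribute as well.

```python
import torch

# Two hypothetical groups with the same mean (0) and the same size
a = torch.tensor([-1.0, 1.0])    # population std 1
b = torch.tensor([-3.0, 3.0])    # population std 3

std_a = a.std(unbiased=False)
std_b = b.std(unbiased=False)

# Averaging the standard deviations gives 2.0
mean_of_stds = (std_a + std_b) / 2

# Averaging the variances and then taking the root gives sqrt(5), about 2.236
pooled_std = torch.sqrt((a.var(unbiased=False) + b.var(unbiased=False)) / 2)

# Standard deviation of all four values together
true_std = torch.cat([a, b]).std(unbiased=False)
```

Here pooled_std matches true_std, while mean_of_stds underestimates it.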
You just need to rearrange the batch tensor in the right way, from [B, C, W, H] to [B, C, W * H], using:
batch = batch.view(batch.size(0), batch.size(1), -1)
Here is a complete usage example on random data:
Code:
import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate per-image, per-channel mean and std
    mean += batch.mean(2).sum(0)
    std += batch.std(2).sum(0)

# Final step: average the per-image statistics over the dataset
mean /= nimages
std /= nimages

print(mean)
print(std)
Output:
tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
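When the whole dataset fits in memory as a single tensor, the per-channel statistics can also be computed in one step by reducing over the batch and spatial axes. The sketch below uses random data with an assumed scale and offset, and applies the normalization by hand with broadcasting; this is the same per-channel arithmetic, (input - mean) / std, that transforms.Normalize(mean, std) is documented to perform.

```python
import torch

torch.manual_seed(0)
# Hypothetical dataset held in memory as one [N, C, W, H] tensor
data = torch.randn(64, 3, 28, 28) * 2.0 + 0.5

# Reduce over every axis except the channel axis; both results have shape (C,)
mean = data.mean(dim=(0, 2, 3))
std = data.std(dim=(0, 2, 3))

# Reshape the statistics to (1, C, 1, 1) so they broadcast across the batch
normalized = (data - mean[None, :, None, None]) / std[None, :, None, None]
```

After this step each channel of normalized has a mean close to 0 and a standard deviation close to 1.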