
How to convert RGB images to grayscale in PyTorch dataloader?

Tags:

python

pytorch

I've downloaded some sample images from the MNIST dataset in .jpg format. Now I'm loading those images for testing my pre-trained model.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# MNIST dataset
test_dataset = datasets.ImageFolder(root=DATA_PATH, transform=trans)

# Data loader
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Here, DATA_PATH contains a subfolder with the sample image.

Here's my network definition:

import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.network2D = nn.Sequential(
           nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),
           nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2))
        self.network1D = nn.Sequential(
           nn.Dropout(),
           nn.Linear(7 * 7 * 64, 1000),
           nn.Linear(1000, 10))

    def forward(self, x):
        out = self.network2D(x)
        out = out.reshape(out.size(0), -1)
        out = self.network1D(out)
        return out

And this is my inference part:

import torch

# Test the model
model = torch.load("mnist_weights_5.pth.tar")
model.eval()

# no gradients are needed for inference
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.cuda())

When I run this code, I get the following error:

RuntimeError: Given groups=1, weight of size [32, 1, 5, 5], expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead

I understand that the images are being loaded with 3 channels (RGB). How do I convert them to a single channel in the dataloader?

Update: I changed the transforms to include the Grayscale option:

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Grayscale(num_output_channels=1)])

But now I get this error:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
Harsh Wardhan asked Sep 21 '18 08:09


People also ask

How do I convert an image to grayscale in PyTorch?

The torchvision.transforms.Grayscale() transform converts an image to grayscale. If the input image is a torch Tensor, it is expected to have shape [..., 3, H, W], where ... means any number of leading dimensions and H and W are the height and width.

How do I convert an RGB image to grayscale?

The average method is the simplest one: take the average of the three color channels, i.e. add R, G and B for each pixel and divide the sum by 3.
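A minimal sketch of the average method on a [3, H, W] tensor, using a mean over the channel dimension (the function name average_grayscale is hypothetical):

```python
import torch


def average_grayscale(rgb: torch.Tensor) -> torch.Tensor:
    """Convert a [3, H, W] RGB tensor to [1, H, W] by averaging R, G and B."""
    # Mean over the channel axis; keepdim=True keeps a single channel dimension
    return rgb.mean(dim=0, keepdim=True)


# Two pixels (H=1, W=2): pure red and pure white
img = torch.tensor([[[1.0, 1.0]],   # R
                    [[0.0, 1.0]],   # G
                    [[0.0, 1.0]]])  # B
gray = average_grayscale(img)
print(gray.shape)  # torch.Size([1, 1, 2])
```

In a pipeline it could be wrapped as transforms.Lambda(average_grayscale) and placed after ToTensor.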


2 Answers

When you use the ImageFolder class without a custom loader, PyTorch loads each image with PIL and converts it to RGB. This is the default loader when the torchvision image backend is PIL:

from PIL import Image


def pil_loader(path):
    # open the path as a file object, then force a 3-channel RGB image
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

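You can use torchvision's Grayscale transform to convert the 3-channel RGB image into a 1-channel grayscale image. Put it before ToTensor: in the torchvision version used in the question it only accepts PIL images, which is why placing it after ToTensor raises "img should be PIL Image". Find out more about it here.

Here is some sample code: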

import torchvision as tv
import torch.utils.data as data

dataDir         = 'D:\\general\\ML_DL\\datasets\\CIFAR'
# Grayscale comes first because it operates on the PIL image before ToTensor
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                         tv.transforms.ToTensor(),
                                         # one mean/std value per channel, and there is only one channel now
                                         tv.transforms.Normalize((0.5,), (0.5,))])
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader      = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
images, labels  = next(iter(dataloader))
print(images.size())  # torch.Size([1, 1, 32, 32])
Thiagarajan answered Sep 28 '22 19:09


Instead of building the DataLoader on ImageFolder, you can implement your own dataset (a data generator) that loads the images directly in its __getitem__ method: open each image with PIL.Image.open(".."), convert it to grayscale, then to a NumPy array, and finally to a Tensor.
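A minimal sketch of that approach, assuming a plain list of file paths and integer labels. The class GrayImageDataset and the temporary image files are hypothetical, created only so the example runs:

```python
import os
import tempfile

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class GrayImageDataset(Dataset):
    """Hypothetical dataset: loads image files and returns 1-channel tensors."""

    def __init__(self, paths, labels):
        self.paths = paths
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Open with PIL and convert to grayscale ('L' mode)
        img = Image.open(self.paths[idx]).convert('L')
        # To NumPy as float32, scaled to [0, 1]
        arr = np.asarray(img, dtype=np.float32) / 255.0
        # To a tensor with a leading channel axis: [H, W] -> [1, H, W]
        tensor = torch.from_numpy(arr).unsqueeze(0)
        return tensor, self.labels[idx]


# Two throwaway RGB images so the example is runnable
tmp = tempfile.mkdtemp()
paths = []
for i in range(2):
    p = os.path.join(tmp, f'img{i}.png')
    Image.new('RGB', (28, 28), color=(i * 100, 50, 25)).save(p)
    paths.append(p)

loader = DataLoader(GrayImageDataset(paths, [0, 1]), batch_size=2, shuffle=False)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([2, 1, 28, 28])
```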

Another option is to compute the grayscale (Y) channel from RGB with the formula Y = 0.299 R + 0.587 G + 0.114 B: slice the array into its R, G and B channels and combine them into a single channel.
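A minimal sketch of the weighted formula on a [3, H, W] tensor (the function name luma_grayscale is hypothetical). It slices out the three channels, combines them with the weights above, and keeps a single channel axis:

```python
import torch


def luma_grayscale(rgb: torch.Tensor) -> torch.Tensor:
    """Weighted grayscale Y = 0.299 R + 0.587 G + 0.114 B for a [3, H, W] tensor."""
    r, g, b = rgb[0], rgb[1], rgb[2]  # each channel has shape [H, W]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    return y.unsqueeze(0)  # back to one channel: [1, H, W]


# Pure red, green and blue pixels side by side (H=1, W=3)
img = torch.tensor([[[1.0, 0.0, 0.0]],
                    [[0.0, 1.0, 0.0]],
                    [[0.0, 0.0, 1.0]]])
out = luma_grayscale(img)
print(out)
```

To use it in a pipeline, it could be wrapped as transforms.Lambda(luma_grayscale) and placed after ToTensor.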

But how did you train your model? Usually the training and test data are loaded the same way.

chro answered Sep 28 '22 20:09