 

PyTorch - How to use "toPILImage" correctly

Tags:

python

pytorch

I would like to know whether I used ToPILImage from torchvision correctly. I want to use it to see how the images look after the initial image transformations are applied to the dataset.

When I use it as in the code below, the image that comes up has weird colors. The original image is a regular RGB image.

This is my code:

import os
import torch
from PIL import Image, ImageFont, ImageDraw
import torch.utils.data as data
import torchvision
from torchvision import transforms    
import matplotlib.pyplot as plt

# Image transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
    )
transform_img = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    normalize ])

train_data = torchvision.datasets.ImageFolder(
    root='./train_cl/',
    transform=transform_img
    )
test_data = torchvision.datasets.ImageFolder(
    root='./test_named_cl/',
    transform=transform_img                                             
    )

train_data_loader = data.DataLoader(train_data,
    batch_size=4,
    shuffle=True,
    num_workers=4) #num_workers=args.nThreads)

test_data_loader = data.DataLoader(test_data,
    batch_size=32,
    shuffle=False,
    num_workers=4)        

# Open Image from dataset:
to_pil_image = transforms.ToPILImage()
my_img, _ = train_data[248]
results = to_pil_image(my_img)
results.show()

Edit:

I had to use .data on the torch Variable to get the tensor. Also, I needed to rescale the numpy array before transposing. I found a working solution here, but it doesn't always work well. How can I do this better?

import numpy as np
from torch.autograd import Variable

# Named "batch" so it doesn't shadow the torch.utils.data module imported as "data"
for i, batch in enumerate(train_data_loader, 0):
    img, labels = batch
    img = Variable(img)
    break

image = img.data.cpu().numpy()[0]

# This worked for rescaling:
image = (1/(2*2.25)) * image + 0.5

# Both of these didn't work:
# image /= (image.max()/255.0)
# image *= (255.0/image.max())

image = np.transpose(image, (1,2,0))
plt.imshow(image)
plt.show()
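The rescaling (1/(2*2.25)) * image + 0.5 only approximates the inverse of Normalize. Since Normalize computes (x - mean) / std for each channel, the exact inverse is x * std + mean, applied per channel before the transpose. A sketch in NumPy, assuming the mean and std from the question; the helper name to_displayable is hypothetical, and random data stands in for the batch.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_displayable(image_chw):
    """Undo Normalize on one C x H x W array and return H x W x C in [0, 1]."""
    # Reshape mean/std to (3, 1, 1) so they broadcast over height and width
    image = image_chw * STD[:, None, None] + MEAN[:, None, None]
    # Guard against small numeric overshoot outside [0, 1]
    image = np.clip(image, 0.0, 1.0)
    # Channels-first -> channels-last, the layout plt.imshow expects
    return np.transpose(image, (1, 2, 0))


# Stand-in for `img.cpu().numpy()[0]`: a normalized random image
rng = np.random.default_rng(0)
original = rng.random((3, 8, 8), dtype=np.float32)
normalized = (original - MEAN[:, None, None]) / STD[:, None, None]

image = to_displayable(normalized)
# plt.imshow(image); plt.show()
```

With PyTorch 0.4 or later, Variable is no longer needed: img.cpu().numpy()[0] works directly on the batch tensor from the DataLoader.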
asked Feb 28 '18 at 17:02 by kett


People also ask

How do I load image data with PyTorch?

When loading image data with PyTorch, the ImageFolder class works very well. If you are collecting the image data yourself, I would suggest organizing it so that ImageFolder can read it directly. Data that doesn't follow this layout takes more work to load.

What does the ToPILImage() transform do?

The ToPILImage() transform converts a torch tensor or a NumPy array to a PIL image. The torchvision.transforms module provides many transforms that can be used to manipulate image data in different ways.

What NumPy arrays are supported by to_pil_image?

Looking into the source code of to_pil_image, you can see that only NumPy arrays of types np.{uint8, int16, int32, float32} are supported.


2 Answers

You can use a PIL image, but you're not actually loading the data as you normally would.

Try something like this instead:

import numpy as np
import matplotlib.pyplot as plt

for img,labels in train_data_loader:
    # load a batch from train data
    break

# this converts it from GPU to CPU and selects first image
img = img.cpu().numpy()[0]
#convert image back to Height,Width,Channels
img = np.transpose(img, (1,2,0))
#show the image
plt.imshow(img)
plt.show()  

As an update (02-10-2021):

import torch
import torchvision.transforms.functional as F
# load the image (creating a random image as an example)
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
pil_image = F.to_pil_image(img_data)

Alternatively

import torch
import torchvision.transforms as transforms
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
pil_image = transforms.ToPILImage()(img_data)

The second form can be integrated into the transform pipeline of a PyTorch dataset, or called directly as shown.

I added a modified to_pil_image here

Essentially, it does what I suggested back in 2018, but it is now integrated into PyTorch.

answered Sep 18 '22 at 10:09 by Steven


I would use something like this:

# Open Image from dataset:
my_img, _ = train_data[248]
results = transforms.ToPILImage()(my_img)
results.show()
answered Sep 18 '22 at 10:09 by SpeedOfSpin