img should be PIL Image. Got <class 'torch.Tensor'>

Tags: python, pytorch
I'm trying to iterate through a DataLoader to check that it works, but I get the following error:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

I've tried adding both transforms.ToTensor() and transforms.ToPILImage(), and each one gives an error asking for the opposite type: with ToPILImage() it asks for a tensor, and with ToTensor() it asks for a PIL image.

# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np

data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

# Creating transform for training set
train_transforms = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Creating transform for test set
test_transforms = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Transforming all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform=test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform=test_transforms)

# Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size=32,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))

If it's working correctly, I should be able to view an image simply by running plt.imshow(images[0]).
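Even once the loader works, plt.imshow(images[0]) expects an (H, W, C) array with values in [0, 1], whereas each image in the batch is a normalized (C, H, W) tensor, so matplotlib will reject or distort it as-is. A minimal sketch of the conversion, using a hypothetical helper named to_displayable and assuming the same mean/std as the Normalize() call above:

```python
import numpy as np
import torch

# Same statistics as the question's transforms.Normalize call
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def to_displayable(img: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: turn a normalized (C, H, W) tensor into an (H, W, C) array in [0, 1]."""
    arr = img.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
    arr = arr * STD + MEAN  # undo transforms.Normalize channel by channel
    return np.clip(arr, 0.0, 1.0)  # keep values in the range imshow expects for floats
```

In the notebook, plt.imshow(to_displayable(images[0])) should then show the first image of the batch.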

Hamza Usman, asked Jul 17 '19 15:07


1 Answer

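If your dataset yields tensors or NumPy arrays, add transforms.ToPILImage() at the start of the pipeline to convert each sample into a PIL image. Either way, place PIL-based augmentations such as RandomHorizontalFlip() before ToTensor(); in the question's code it comes after ToTensor(), and that is what raises "img should be PIL Image". For example: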

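from torchvision import transforms

transform = transforms.Compose([
    # Only needed when the dataset yields tensors or NumPy arrays;
    # ImageFolder already returns PIL images, so omit this line there.
    transforms.ToPILImage(),
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),  # PIL-based augmentation: keep it before ToTensor()
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])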
Biplob Das, answered Oct 22 '22 23:10