I'm trying to iterate through a loader to check if it's working, however the below error is given:
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
I've tried adding both transforms.ToTensor()
and transforms.ToPILImage()
and it gives me an error asking for the opposite. i.e, with ToPILImage()
, it will ask for tensor, and vice versa.
# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np
data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'
#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)
#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32,
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))
It should allow me to simply see the image once I run plt.imshow(images[0])
, if its working correctly.
Just add transforms.ToPILImage()
to convert into pil image and then it will work, example:
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With