
TypeError: tensor is not a torch image

Tags: python, pytorch

While working through the AI course at Udacity I came across this error during the Transfer Learning section. Here is the code that seems to be causing the trouble:

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

data_dir = 'filename'

# TODO: Define transforms for the training data and testing data
train_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])
test_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])

# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
asked Aug 12 '18 08:08 by Tristan Newman


3 Answers

The problem is the order of the transforms. ToTensor must come before Normalize: Normalize expects a tensor, but Resize returns a PIL image, so in the original order Normalize receives an image instead of a tensor. Corrected code, with only the faulty lines changed:

train_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
test_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
answered Nov 19 '22 11:11 by Agost Biro


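Another, less elegant solution (assuming the image was loaded with OpenCV, so img is an H x W x 3 uint8 array in BGR order):

# ToPILImage accepts the NumPy array; ToTensor scales the pixel values to [0, 1]
t_ = transforms.Compose([transforms.ToPILImage(),
                         transforms.Resize((224,224)),
                         transforms.ToTensor()])

# Per-channel means on the 0-255 scale, in BGR order; a std of 1 leaves the scale unchanged
norm_ = transforms.Normalize([103.939, 116.779, 123.68],[1,1,1])
img = 255*t_(img)  # back to the 0-255 range
img = norm_(img)   # subtract the per-channel means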
answered Nov 19 '22 11:11 by Alex


Another cause of this error: in the torchvision versions that raise it, transforms.Normalize only accepts a 3D tensor of shape (C, H, W), e.g. (3, 224, 224). Here is some example code:

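# imagenet normalize

import numpy as np
import torch
from torchvision.transforms import Normalize

# the original snippet left device undefined; fall back to the CPU if no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = Normalize(mean, std)

# batch of 10 random images, shape (10, 224, 224, 3)
img = np.random.choice(255, (10, 224, 224, 3))
img = img/255 # [0, 1]
img = torch.tensor(img, device=device).float().permute(0, 3, 1, 2)  # -> (10, 3, 224, 224)
img = normalize(img)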

This throws the error because the input is 4D. With a single 3D image the error disappears. Note that the permute call must then take three axes, permute(2, 0, 1), instead of permute(0, 3, 1, 2):

img = np.random.choice(255, (224, 224, 3))
answered Nov 19 '22 11:11 by Shark Deng