This is the error I get when I try to train my network.
The class we use to store images from the Caltech 101 dataset was provided to us by our teachers.
from torchvision.datasets import VisionDataset
from PIL import Image

import os
import os.path
import sys


def pil_loader(path):
    # Open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
class Caltech(VisionDataset):
    def __init__(self, root, split='train', transform=None, target_transform=None):
        super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)

        self.split = split  # This defines the split you are going to use
                            # (split files are called 'train.txt' and 'test.txt')

        '''
        - Here you should implement the logic for reading the splits files and accessing elements
        - If the RAM size allows it, it is faster to store all data in memory
        - PyTorch Dataset classes use indexes to read elements
        - You should provide a way for the __getitem__ method to access the image-label pair
          through the index
        - Labels should start from 0, so for Caltech you will have labels 0...100 (excluding the background class)
        '''

        # Open the split file in read-only mode and read all lines
        with open(self.split, "r") as file:
            lines = file.readlines()

        # Filter out the lines which start with 'BACKGROUND_Google', as asked in the homework
        self.elements = [i for i in lines if not i.startswith('BACKGROUND_Google')]

        # Delete the BACKGROUND_Google class from the dataset labels
        self.classes = sorted(os.listdir(os.path.join(self.root, "")))
        self.classes.remove("BACKGROUND_Google")
    def __getitem__(self, index):
        '''
        __getitem__ should access an element through its index
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        '''
        img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
        target = self.classes.index(self.elements[index].rstrip().split('/')[0])

        image, label = img, target  # Provide a way to access image and label via index
                                    # Image should be a PIL Image
                                    # label can be int

        # Applies the preprocessing when accessing the image
        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        '''
        The __len__ method returns the length of the dataset
        It is mandatory, as this is used by several other components
        '''
        # Provides a way to get the length (number of elements) of the dataset
        length = len(self.elements)
        return length
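Just to show how I understand the split files: every line is a relative path whose first folder is the class name, and __getitem__ turns that folder name into the integer label. A tiny sketch with a made-up entry (the actual lines in train.txt may differ):

# Sketch only: hypothetical split-file line and a reduced class list
classes = ['accordion', 'airplanes', 'anchor']       # hypothetical subset of self.classes
line = 'accordion/image_0001.jpg\n'                  # hypothetical entry from train.txt

relative_path = line.rstrip()                        # drop the trailing newline
label = classes.index(relative_path.split('/')[0])   # class folder name -> integer label
print(relative_path, label)                          # accordion/image_0001.jpg 0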
The preprocessing phase, meanwhile, is done by this code:
# Define transforms for the training phase
train_transform = transforms.Compose([transforms.Resize(256),      # Resizes the short side of the PIL image to 256
                                      transforms.CenterCrop(224),  # Crops a central square patch of the image
                                                                   # 224 because torchvision's AlexNet needs a 224x224 input!
                                                                   # Remember this when applying different transformations, otherwise you get an error
                                      transforms.ToTensor(),       # Turns the PIL Image into a torch.Tensor
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalizes the tensor with mean and standard deviation
                                     ])
# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ])
Finally, this is the preparation of the datasets and dataloaders:
# Clone the GitHub repository with the data
if not os.path.isdir('./Homework2-Caltech101'):
    !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git

# Commands to execute when there is an error saying no file or directory related to ./Homework2-Caltech101/
# !rm -r ./Homework2-Caltech101/
# !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git
DATA_DIR = 'Homework2-Caltech101/101_ObjectCategories'
SPLIT_TRAIN = 'Homework2-Caltech101/train.txt'
SPLIT_TEST = 'Homework2-Caltech101/test.txt'
# 1 - Data preparation
myTrainDS = Caltech(DATA_DIR, split = SPLIT_TRAIN, transform=train_transform)
myTestDS = Caltech(DATA_DIR, split = SPLIT_TEST, transform=eval_transform)
print('My Train DS: {}'.format(len(myTrainDS)))
print('My Test DS: {}'.format(len(myTestDS)))
# 2 - Dataloader preparation
myTrain_dataloader = DataLoader(myTrainDS, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
myTest_dataloader = DataLoader(myTestDS, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
Okay, now: the two .txt files contain the lists of images we want in the train and test splits, so we have to read them from there, but that part should have been done correctly. The thing is that when I reach the training phase (see the code later), I get the error in the title. I already tried to add the following line to the transform function:
[...]
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
after the CenterCrop, but it says that Image has no attribute 'repeat', so I'm kind of stuck.
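I think the reason that attempt fails is that repeat is a torch.Tensor method, so right after CenterCrop the Lambda still receives a PIL Image. As far as I understand, that kind of Lambda would only run if placed after ToTensor(), roughly like this (just a sketch of the placement, not something I have verified to fix the problem):

train_transform = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      # here the input is a tensor of shape [C, 224, 224], so repeat works on it
                                      transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])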
The training code line which gives me the error is the following:
# Iterate over the dataset
for images, labels in myTrain_dataloader:
If needed, full error is:
RuntimeError Traceback (most recent call last)
<ipython-input-197-0e4710a9855d> in <module>()
47
48 # Iterate over the dataset
---> 49 for images, labels in myTrain_dataloader:
50
51 # Bring data over the device of choice
2 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
817 else:
818 del self._task_info[idx]
--> 819 return self._process_data(data)
820
821 next = __next__ # Python 2 compatibility
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
844 self._try_put_index()
845 if isinstance(data, ExceptionWrapper):
--> 846 data.reraise()
847 return data
848
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
383 # (https://bugs.python.org/issue2651), so we work around it.
384 msg = KeyErrorMessage(msg)
--> 385 raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-180-0b00b175e18c>", line 72, in __getitem__
image = self.transform(image)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 70, in __call__
img = t(img)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 175, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 217, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
I'm using AlexNet and the code I was provided is the following:
net = alexnet() # Loading AlexNet model
# AlexNet has 1000 output neurons, corresponding to the 1000 ImageNet's classes
# We need 101 outputs for Caltech-101
net.classifier[6] = nn.Linear(4096, NUM_CLASSES) # nn.Linear in pytorch is a fully connected layer
# The convolutional layer is nn.Conv2d
# We just changed the last layer of AlexNet with a new fully connected layer with 101 outputs
# It is mandatory to study torchvision.models.alexnet source code
The first dimension of the tensor is the color channel, so what your error means is that you are feeding a grayscale picture (1 channel) while Normalize expects an RGB image (3 channels). You defined a pil_loader function that returns an image in RGB, but you never use it.
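A minimal sketch of what is going on (assuming a 1-channel tensor reaching Normalize, which is what your traceback shows):

import torch
from torchvision import transforms

# Normalize with a 3-value mean/std cannot broadcast over a 1-channel tensor
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

rgb = torch.rand(3, 224, 224)    # what ToTensor produces for an RGB image
gray = torch.rand(1, 224, 224)   # what it produces for a grayscale image

normalize(rgb)    # fine
normalize(gray)   # RuntimeError: output with shape [1, 224, 224] doesn't match
                  # the broadcast shape [3, 224, 224]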
So you have two options:

1. Work with the images in grayscale instead of RGB, which is computationally cheaper.
   Solution: in both the train and test transforms, change
   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
   to
   transforms.Normalize((0.5,), (0.5,))

2. Make sure your images are in RGB. I don't know how your images are stored, but I guess some of them are saved in grayscale. One thing you could try is using the pil_loader function you defined: change
   img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
   to
   img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))
   in your __getitem__ method (see the sketch after this list).
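As a rough sketch of option 2, only the loading line of your __getitem__ changes (everything else stays as you posted it):

    def __getitem__(self, index):
        # Use the pil_loader you already defined, so every image is converted
        # to RGB (3 channels) before the transforms run
        img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))
        target = self.classes.index(self.elements[index].rstrip().split('/')[0])

        image, label = img, target
        if self.transform is not None:
            image = self.transform(image)
        return image, label

Note that if you go with option 1 and the dataset mixes RGB and grayscale files, you would probably also need a transforms.Grayscale(num_output_channels=1) before ToTensor(), so that every image ends up with exactly one channel.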
Let me know how it goes!