
PyTorch - Incorrect labeling using torchvision.datasets.ImageFolder

I have structured my dataset in the following way:

dataset/train/0/456.jpg
dataset/train/1/456456.jpg
dataset/train/2/456.jpg
dataset/train/...

dataset/val/0/878.jpg
dataset/val/1/234.jpg
dataset/val/2/34554.jpg
dataset/val/...

I load the dataset into PyTorch with torchvision.datasets.ImageFolder. However, the labels it returns do not seem to match the images. Here is my code:

import os

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose(
        [transforms.Resize((176,176)),
         transforms.RandomRotation((0,360)),
         transforms.RandomHorizontalFlip(),
         transforms.RandomVerticalFlip(),
         transforms.CenterCrop(128),
         transforms.Grayscale(),
         transforms.ToTensor(),
         # Grayscale() outputs one channel, so Normalize takes one mean and one std
         transforms.Normalize((0.5,), (0.5,))
    ]),
    'val': transforms.Compose(
        [transforms.Resize((128,128)),
         transforms.Grayscale(),
         transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))
    ]),
}

data_dir = 'dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

I found out that the labels are wrong using the following function:

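import matplotlib.pyplot as plt
import numpy as np
import torchvision

def imshow(img):
    img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# Fetch one batch from the validation loader
dataiter = iter(dataloaders['val'])
images, labels = next(dataiter)  # recent PyTorch versions no longer provide dataiter.next()

imshow(torchvision.utils.make_grid(images))
print(labels)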

Using the shown images and the labels, I manually checked whether they are correct. Unfortunately, the labels do not correspond to the images. Can someone tell me what I'm doing wrong?

Zarif asked Jan 01 '23 12:01


2 Answers

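Someone helped me out with this. ImageFolder does not use the folder names as labels directly: it sorts the class folder names as strings and assigns each one an internal index, starting from 0. Because the sort is alphabetical, a folder named 10 comes before a folder named 2, so with numerically named folders the internal index can differ from the folder name once there are more than ten classes. Printing image_datasets['train'].class_to_idx shows which folder name is paired with which internal index. Using this dictionary, you can trace a label back to the original one.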
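
To make the tracing-back step concrete, here is a minimal sketch in plain Python. It does not call torchvision; it only imitates the numbering ImageFolder applies (folder names sorted as strings, then indexed from 0). The 12-class folder list and the original_labels helper are made up for illustration and are not part of the question's code.

```python
# Hypothetical folder names under dataset/train: "0", "1", ..., "11".
folder_names = [str(i) for i in range(12)]

# Imitate ImageFolder: sort the names as strings, then number them 0..N-1.
classes = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Invert the mapping so an internal label can be traced back to its folder.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}


def original_labels(internal_labels):
    """Map internal indices (e.g. labels.tolist()) back to numeric folder names."""
    return [int(idx_to_class[i]) for i in internal_labels]


print(class_to_idx)
print(original_labels([0, 2, 4]))
```

With the real dataset you would skip the first lines and invert image_datasets['train'].class_to_idx directly, then pass labels.tolist() from a batch to the helper.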

Zarif answered Jan 04 '23 01:01


The ImageFolder API assumes that your data follows a predefined folder structure. See this comment from the PyTorch source code, also in the documentation at https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

This means you need to arrange your data in folders named after your labels. In the example above there are two labels, cat and dog.
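
If you want to check what labels a given folder layout would produce, a rough sketch like the one below may help. It builds a throwaway tree shaped like the docstring example and lists the class folders sorted by name, which is my understanding of how ImageFolder derives its class indices. The helper list_class_folders is a hypothetical stand-in written for this example, not a torchvision function, and the image files are empty placeholders.

```python
import os
import tempfile


def list_class_folders(root):
    """Return (classes, class_to_idx) built from the sub-folders of root, sorted by name."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Build a temporary root/dog/... and root/cat/... tree with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for label, filename in [("dog", "xxx.png"), ("dog", "xxy.png"), ("cat", "123.png")]:
        os.makedirs(os.path.join(root, label), exist_ok=True)
        open(os.path.join(root, label, filename), "wb").close()
    classes, class_to_idx = list_class_folders(root)

print(classes)
print(class_to_idx)
```

Here cat ends up with index 0 and dog with index 1, because the folder names are sorted before being numbered.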

Hope this helps!

Mohana Rao answered Jan 04 '23 03:01