 

PyTorch: Image label

I am working on an image classifier with 31 classes (Office dataset). There is one folder for each class. I have a Python script written with PyTorch that loads the dataset using datasets.ImageFolder, assigns a label to each image, and then trains. Here is my code snippet for loading the data:

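from torchvision import datasets, transforms
import torch

def load_training(root_path, dir, batch_size, kwargs):
    # Training-time preprocessing: resize, random 224x224 crop, random flip, convert to tensor
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    # ImageFolder treats each sub-folder of root_path + dir as one class
    # (root_path + dir must form a valid path, e.g. root_path ending in '/')
    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    # kwargs is a dict of extra DataLoader options (e.g. num_workers, pin_memory)
    train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    return train_loader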

The code treats each folder as one class and assigns the same label to every image in that folder. Is there any way to find out which label is assigned to which image or image folder?

tahsin314 asked Aug 18 '18 at 07:08



1 Answer

The ImageFolder class has an attribute class_to_idx, a dictionary that maps each class name (the folder name) to its index (the label). So you can list the classes with data.classes and, for each class, look up its label in data.class_to_idx.
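To go from a label back to a folder, or to see the label of each individual image, the dictionary can be inverted and combined with ImageFolder's samples attribute, a list of (image path, label) pairs. The sketch below assumes `dataset` is the ImageFolder from the question (reachable from the loader as train_loader.dataset); the helper names idx_to_class and image_labels are hypothetical, and only the class_to_idx and samples attributes are relied on.

```python
from typing import Dict, List, Tuple


def idx_to_class(dataset) -> Dict[int, str]:
    """Invert ImageFolder's class_to_idx (folder name -> label) into label -> folder name."""
    return {idx: name for name, idx in dataset.class_to_idx.items()}


def image_labels(dataset) -> List[Tuple[str, int, str]]:
    """List (image path, label, folder name) for every image in the dataset.

    Uses dataset.samples, which ImageFolder fills with (path, label) pairs.
    """
    lookup = idx_to_class(dataset)
    return [(path, label, lookup[label]) for path, label in dataset.samples]
```

With the loader from the question, iterating over image_labels(train_loader.dataset) gives every file together with its label and folder name. Because the labels are positions in data.classes, data.classes[label] returns the same folder name as the inverted dictionary.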

For reference: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
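The linked folder.py derives the labels purely from the directory layout: it sorts the class sub-directory names and numbers them from 0, so every image under a folder gets that folder's number. The plain-Python sketch below mimics that rule so the mapping can be checked without loading any images. It is not torchvision's code: folder_class_to_idx, folder_samples and the extension list are assumptions modelled on the library's defaults, and options such as a custom is_valid_file are ignored.

```python
import os
from typing import Dict, List, Tuple

# Assumed to mirror torchvision's default image extensions; matched case-insensitively.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def folder_class_to_idx(root: str) -> Dict[str, int]:
    """Sort the sub-directory names of root and number them 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def folder_samples(root: str) -> List[Tuple[str, int]]:
    """Pair every image file with the label of the class folder it sits in."""
    class_to_idx = folder_class_to_idx(root)
    samples = []
    for name in sorted(class_to_idx):
        class_dir = os.path.join(root, name)
        # Walk nested sub-folders too; every image under class_dir gets the same label.
        for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return samples
```

On the Office folders, folder_class_to_idx(root_path + dir) should agree with data.class_to_idx, which makes it easy to see that labels follow the sorted order of the folder names rather than anything stored in the images themselves.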

Jan answered Oct 21 '22 at 06:10
