I am working on an image classifier with 31 classes (the Office dataset). There is one folder for each class. I have a Python script written with PyTorch that loads the dataset using datasets.ImageFolder, assigns a label to each image, and then trains. Here is my code snippet for loading the data:
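```python
import os

import torch
from torchvision import datasets, transforms


def load_training(root_path, dir, batch_size, kwargs):
    # Resize every image, then apply a random 224x224 crop and a random
    # horizontal flip as training-time augmentation.
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    # One subfolder per class; ImageFolder labels each image by its folder.
    # os.path.join avoids a broken path when root_path has no trailing slash.
    data = datasets.ImageFolder(root=os.path.join(root_path, dir), transform=transform)
    # kwargs is a dict of extra DataLoader options (e.g. num_workers, pin_memory).
    train_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    return train_loader
```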
The code takes each folder and assigns the same label to all images in that folder. Is there any way to find which label is assigned to which image or image folder?
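The ImageFolder class has an attribute class_to_idx, which is a dictionary mapping each class name (the folder name) to its index (the label). So you can list the classes with data.classes and get the label for each class from data.class_to_idx. Because data is a local variable inside load_training, reach it from the returned loader as train_loader.dataset.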
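To get the label of every individual image, a minimal sketch (assuming a torchvision version where ImageFolder exposes samples, a list of (image path, class index) tuples) is to invert class_to_idx and walk samples. The helper name image_labels and the stub class names and paths in the demo are hypothetical; the function only relies on the two attributes, so any ImageFolder-like object works.

```python
from types import SimpleNamespace  # only used for the demo stub below


def image_labels(dataset):
    """Return (path, label_index, class_name) for every sample.

    Hypothetical helper: expects ImageFolder-style attributes
    `class_to_idx` (class name -> index) and `samples` (list of (path, index)).
    """
    # Invert the mapping so a numeric label can be turned back into a folder name.
    idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
    return [(path, idx, idx_to_class[idx]) for path, idx in dataset.samples]


if __name__ == "__main__":
    # Hypothetical stand-in for a real ImageFolder dataset.
    stub = SimpleNamespace(
        class_to_idx={"class_a": 0, "class_b": 1},
        samples=[("root/class_a/img1.jpg", 0), ("root/class_b/img2.jpg", 1)],
    )
    for path, idx, name in image_labels(stub):
        print(idx, name, path)
```

With the loader from the question, call image_labels(train_loader.dataset) instead of using the stub.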
For reference: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
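The source linked above derives the labels from the sorted list of subfolder names, so the index of a class is its position in alphabetical (case-sensitive) order. The sketch below mimics that rule with the standard library only, which lets you predict the mapping without loading any images. The function name folder_class_to_idx is hypothetical, and it is a simplified stand-in: it skips ImageFolder's file-extension filtering and empty-folder checks.

```python
import os
import tempfile


def folder_class_to_idx(root):
    """Simplified mimic of ImageFolder's labelling rule.

    Sorted subfolder names are numbered 0..N-1; plain files in `root`
    are ignored. Does not check that folders contain valid images.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: i for i, name in enumerate(classes)}


if __name__ == "__main__":
    # Build a throwaway directory tree with three class folders.
    with tempfile.TemporaryDirectory() as root:
        for name in ["monitor", "bike", "back_pack"]:
            os.makedirs(os.path.join(root, name))
        print(folder_class_to_idx(root))
        # → {'back_pack': 0, 'bike': 1, 'monitor': 2}
```

Because Python sorts strings by code point, a folder named "Zebra" would get a lower index than "apple"; keep that in mind if your folder names mix upper and lower case.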