I am trying to perform an image classification task using the mini-ImageNet dataset. The data that I want to use contains a few bad data points (I am not sure why). I would like to load this data and train my model on it and, in the process, skip the bad data points completely. How do I do this? The dataset class that I pass to the data loader is as follows:
class MiniImageNet(Dataset):
    """Image dataset driven by a ``train.csv`` / ``test.csv`` split file.

    The split file is read from ``<root>/miniimagenet/split/<setname>.csv``.
    Its first line is treated as a header and skipped; every other line is
    ``<image name>,<wnid>``. Images are looked up under
    ``<root>/miniimagenet/images``. Integer labels are assigned to wnids in
    order of first appearance, starting at 0.

    Args:
        root: Dataset root directory; a leading ``~`` is expanded.
        train: Selects the ``train`` split (and random-crop/flip
            augmentation) when true, otherwise the ``test`` split.
        transform: Accepted for interface compatibility. The built-in
            pipeline chosen by ``train`` always replaces it.
        index_path, index, base_sess: Accepted but not used by this class.
    """

    def __init__(self, root, train=True,
                 transform=None,
                 index_path=None, index=None, base_sess=None):
        setname = 'train' if train else 'test'
        self.root = osp.expanduser(root)
        self.train = train  # training set or test set
        # Build every path from the expanded root so that '~' works.
        self.IMAGE_PATH = osp.join(self.root, 'miniimagenet/images')
        self.SPLIT_PATH = osp.join(self.root, 'miniimagenet/split')
        csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')

        (self.data, self.targets,
         self.wnids, self.data2label) = self._read_split(csv_path,
                                                         self.IMAGE_PATH)
        self.y = self.targets  # alias: same list object as self.targets

        image_size = 84
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if train:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize])
        else:
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize])

    @staticmethod
    def _read_split(csv_path, image_dir):
        """Parse a split CSV into ``(data, targets, wnids, data2label)``.

        ``data`` holds image paths, ``targets`` the matching integer labels,
        ``wnids`` the class ids in first-seen order and ``data2label`` maps
        each path to its label.
        """
        data, targets, wnids = [], [], []
        data2label = {}
        label_of = {}  # wnid -> label, for O(1) lookup
        with open(csv_path, 'r') as split_file:
            next(split_file, None)  # skip the header row
            for raw_line in split_file:
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank lines
                name, wnid = line.split(',')
                if wnid not in label_of:
                    label_of[wnid] = len(wnids)
                    wnids.append(wnid)
                # Look the label up by wnid so that rows of one class need
                # not be contiguous in the file.
                label = label_of[wnid]
                path = osp.join(image_dir, name)
                data.append(path)
                targets.append(label)
                data2label[path] = label
        return data, targets, wnids, data2label

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(transformed image, label)`` for sample ``i``."""
        path, targets = self.data[i], self.targets[i]
        # The context manager closes the underlying file once the pixels
        # have been converted to RGB.
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        return image, targets
I tried wrapping the image loading in a try-except block, but in that case, instead of skipping the item, the dataset returns None to the dataloader, which causes an error. How do I completely skip a data point in a dataloader?
You can do it directly inside the __getitem__ method of the Dataset class. Assume that the dataset initially contains the integers from 1 to 10, but you want to skip the odd elements. You can do it like this:
class SkipBadItems(Dataset):
    """Toy dataset of the integers 1..10 that drops odd values on access.

    Odd values stand in for "bad" samples: when one is requested it is
    deleted from ``self.data`` and the value that slides into the same
    position is tried instead.
    """

    def __init__(self) -> None:
        super().__init__()
        self.data = list(range(1, 11))

    def __getitem__(self, index):
        """Return the first even value at or after ``index`` as a tensor.

        Every odd value met at ``index`` is removed from ``self.data``, so
        the dataset shrinks as it is read. ``StopIteration`` is raised once
        ``index`` is at or past the current end of the list.
        """
        while index < len(self.data):
            candidate = self.data[index]
            if candidate % 2 == 0:
                return torch.tensor(candidate)
            # Bad item: drop it; the next element now sits at `index`.
            del self.data[index]
        raise StopIteration

    def __len__(self):
        """Return the number of values currently left in the dataset."""
        return len(self.data)
# Iterate the loader for ten epochs, printing every batch it yields.
dl = DataLoader(SkipBadItems(), batch_size=1, shuffle=True)
for _epoch in range(10):
    for batch in dl:
        print(batch)
If you love us, you can donate to us via PayPal or buy us a coffee so that we can maintain and grow the site. Thank you!
Donate Us With