pytorch collate_fn reject sample and yield another

I have built a Dataset where I do various checks on the images I'm loading. I'm then passing this Dataset to a DataLoader.

In my Dataset class I return the sample as None if a picture fails my checks, and I have a custom collate_fn which removes all the Nones from the retrieved batch and returns the remaining valid samples.

However, at this point the returned batch can be of varying size. Is there a way to tell the collate_fn to keep sourcing data until the batch reaches a certain size?

import torch
from torch.utils.data import Dataset, DataLoader

class DataSet(Dataset):
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load the csv file and remember the image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __len__(self):
        # return the number of samples listed in the csv file
        ...

    def __getitem__(self, idx):
        # load one sample
        # if the image is too dark return None
        # else return one image and its corresponding label
        ...

def my_collate(batch):
    # a batch of size 4 looks like [{tensor image, tensor label}, {}, {}, {}],
    # but could come back as G = [None, {}, {}, {}] if a sample failed the checks
    batch = list(filter(lambda x: x is not None, batch))  # drop the Nones; the example above becomes G = [{}, {}, {}]
    # I want len(G) = 4
    # so how do I sample another dataset entry?
    return torch.utils.data.dataloader.default_collate(batch)

dataset = DataSet(csv_file='../', image_dir='../../')

dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)
asked Sep 06 '19 by Brian Formento


2 Answers

There are two hacks that can be used to work around the problem; pick one:

Fast option: pad the batch by reusing the original batch samples.

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):
        # pad the batch by repeating existing members
        # (doesn't work if every sample in a batch is rejected)
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(batch[i % len(batch)])
    return torch.utils.data.dataloader.default_collate(batch)

Better option: load another sample from the dataset at random instead.

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):
        # source the missing samples from the original dataset at random
        # (assumes `dataset` is visible here and numpy is imported as np)
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(dataset[np.random.randint(0, len(dataset))])

    return torch.utils.data.dataloader.default_collate(batch)
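
The second snippet assumes that the dataset object is visible inside my_collate (e.g. as a global) and that numpy is imported as np. If you would rather not rely on a global, a minimal sketch of one alternative is to bind the dataset to the collate function with functools.partial (the names below are only illustrative):

import numpy as np
import torch
from functools import partial
from torch.utils.data import DataLoader

def my_collate(batch, dataset):
    # same idea as above, but the dataset is passed in explicitly
    len_batch = len(batch)
    batch = list(filter(lambda x: x is not None, batch))
    diff = len_batch - len(batch)
    for _ in range(diff):
        batch.append(dataset[np.random.randint(0, len(dataset))])
    return torch.utils.data.dataloader.default_collate(batch)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                        num_workers=1, collate_fn=partial(my_collate, dataset=dataset))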
answered Oct 03 '22 by Brian Formento


This worked for me, because with the approach above even the randomly drawn replacements can sometimes be None.

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones

    if len_batch > len(batch):
        # keep drawing random samples until the batch is full,
        # skipping any replacement that is itself None
        db_len = len(dataset)
        diff = len_batch - len(batch)
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:
                continue
            batch.append(a)
            diff -= 1

    return torch.utils.data.dataloader.default_collate(batch)
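
Another way to get the same effect is to resample inside the dataset itself, so the DataLoader can keep the default collate function. This is only a rough sketch under the same assumptions as above (the wrapper class name is made up, and the wrapped dataset must implement __len__ and return None for rejected images):

import numpy as np
from torch.utils.data import Dataset

class ResampleOnNone(Dataset):
    # redraw a random index whenever the wrapped dataset returns None
    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample = self.base[idx]
        while sample is None:
            # the rejected image is replaced by a randomly chosen one
            sample = self.base[np.random.randint(0, len(self.base))]
        return sample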
answered Oct 03 '22 by Mavic More