 

How do I load multiple grayscale images as a single tensor in pytorch?

I'm currently trying to stack a set of images into a single entity for each label so that I can train a CNN on them using cross-validation. Given a dataset of 224x224x1 grayscale images sorted by:

Root/
    Class0/image0_view0.png
    Class0/image0_view1.png
    Class0/image0_view2.png
    ...
    Class1/image0_view0.png
    Class1/image0_view1.png
    Class1/image0_view2.png
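The naming convention in the listing above (imageN_viewK.png inside one folder per class) already carries enough information to group the views of each image, without reorganizing folders. A minimal sketch, assuming every file follows exactly that pattern and has a .png extension; the helper name `group_views` and the `num_views` parameter are hypothetical. Class indices follow the sorted folder names, the same convention torchvision's `ImageFolder` uses.

```python
import re
from collections import defaultdict
from pathlib import Path

# Matches names such as "image0_view2.png" -> image="image0", view="2".
# Assumption: every file follows the imageN_viewK.png pattern from the question.
VIEW_PATTERN = re.compile(r"^(?P<image>.+)_view(?P<view>\d+)\.png$")


def group_views(root, num_views=3):
    """Return a list of (view_paths, class_index) tuples, one per image.

    view_paths is ordered view0, view1, ...; class_index follows the
    sorted class-folder names.
    """
    root = Path(root)
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for class_index, class_name in enumerate(classes):
        groups = defaultdict(dict)  # image id -> {view number: path}
        for path in (root / class_name).iterdir():
            match = VIEW_PATTERN.match(path.name)
            if match:
                groups[match["image"]][int(match["view"])] = str(path)
        for image_id in sorted(groups):
            views = groups[image_id]
            if len(views) != num_views:
                continue  # skip images that are missing a view
            samples.append(([views[v] for v in sorted(views)], class_index))
    return samples
```

Images with a missing view are skipped rather than padded, so every sample has the same number of channels.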

How would I go about feeding 3 images (view 0, 1, and 2) as a single tensor with dimensions 224x224x3 (3 grayscale images)? In other words, how would I create a dataset of image stacks in PyTorch using ImageFolder/DatasetFolder and DataLoader? Would I have to reorganize my folders and classes, or would it be easier to make the stacks when I make the splits for cross-validation?
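One way to approach this is a custom `Dataset` whose items are already stacked, so `DataLoader` needs no changes. A minimal sketch, assuming `samples` is a list of `(view_paths, label)` pairs and the images are 8-bit grayscale PNGs; `StackedViewsDataset` and `kfold_indices` are hypothetical names. PyTorch convolutions expect channels first, so the stacked tensor is 3x224x224 rather than 224x224x3.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class StackedViewsDataset(Dataset):
    """Each item is a (num_views, H, W) float tensor plus its class label.

    samples: list of (view_paths, label) pairs, e.g.
    [(["image0_view0.png", "image0_view1.png", "image0_view2.png"], 0), ...]
    (hypothetical paths).
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform  # optional, applied to the stacked tensor

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        paths, label = self.samples[idx]
        views = []
        for path in paths:
            with Image.open(path) as img:
                # "L" forces single-channel grayscale; scale pixels to [0, 1]
                array = np.asarray(img.convert("L"), dtype=np.float32) / 255.0
            views.append(torch.from_numpy(array))
        # Stack the grayscale views as channels: (3, 224, 224) for 3 views
        x = torch.stack(views, dim=0)
        if self.transform is not None:
            x = self.transform(x)
        return x, label


def kfold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) index lists for k-fold cross-validation."""
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(n_samples, generator=generator).tolist()
    folds = [order[i::k] for i in range(k)]
    for i in range(k):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, folds[i]
```

Because the folds are drawn over whole samples, all three views of an image always land in the same split, which avoids leaking one view into training and another into validation. Each fold can be wrapped with `torch.utils.data.Subset(dataset, train_idx)` and passed to a `DataLoader` as usual.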

Thank you for your time and help! Let me know if I can provide any more info.

jinsom asked Nov 25 '22 23:11



1 Answer

I had a very similar task: I needed to load a random sequence of 3 images as a single element of a batch, so that the network trains on sequences of images rather than on separate images. With a batch size of 8, that gives 8 x 3 = 24 images per batch. This seems very similar to the different views in your case. I used the imread_collection function from skimage.io and added a __getitem__ like this to my Dataset class:

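import torch
from skimage.io import imread_collection

def __getitem__(self, idx):
    # Random centre index; self.boundary (>= 1) keeps idx_q - 1 and idx_q + 1 in range.
    # Note that the idx argument passed in by the DataLoader is not used here.
    idx_q = int(torch.randint(self.boundary, self.length - self.boundary, (1,)))

    # Load three consecutive images; self.image_paths is the list of paths to all images
    q = imread_collection([self.image_paths[idx_q - 1], self.image_paths[idx_q], self.image_paths[idx_q + 1]], conserve_memory=True)

    if self.transform:
        # Transform each image (e.g. ToTensor) and stack them along a new first dimension
        q = torch.stack([self.transform(img) for img in q])

    # The original snippet returned q, p, n; p and n come from the author's wider setup
    # and are not defined in this excerpt, so only the stacked sequence is returned here.
    return q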

Here I generate a random image index and then load three consecutive images with imread_collection, using self.image_paths, which is the list of paths to all images. Then I transform each image and stack the results. In your case, you should think about using the right indices, maybe by applying a sliding window over self.image_paths.
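The sliding-window suggestion can be made deterministic by deriving the window from `idx` instead of drawing a random index. A minimal sketch, assuming the images are already loaded as (1, 224, 224) tensors (random tensors stand in for the files here); `WindowDataset`, `window`, and `stride` are hypothetical names. With `stride=3` and paths sorted view0, view1, view2 for each image, every item is exactly one image's set of views; the batch shape also shows the 8 x 3 = 24 images per batch mentioned above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class WindowDataset(Dataset):
    """Item idx covers images idx*stride, ..., idx*stride + window - 1.

    stride=1 gives an overlapping sliding window; stride=window gives
    non-overlapping groups. `images` stands in for the loaded files.
    """

    def __init__(self, images, window=3, stride=1):
        self.images = images
        self.window = window
        self.stride = stride

    def __len__(self):
        # Number of complete windows that fit in the list
        return max(0, (len(self.images) - self.window) // self.stride + 1)

    def __getitem__(self, idx):
        start = idx * self.stride
        # Stack the window along a new first dimension: (window, 1, H, W)
        return torch.stack(self.images[start:start + self.window])


# Hypothetical stand-ins for 30 loaded grayscale images of shape (1, 224, 224)
images = [torch.rand(1, 224, 224) for _ in range(30)]
loader = DataLoader(WindowDataset(images, window=3, stride=3), batch_size=8)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([8, 3, 1, 224, 224]): 8 x 3 = 24 images
```

The default collate function simply stacks the items, so no custom `collate_fn` is needed as long as every item has the same window length.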

More information can be found on the PyTorch forum. I also asked there and looked for a more elegant solution but couldn't find one, and I trained the model successfully with this approach.

dinarkino answered Nov 28 '22 11:11
