Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

I am having a hard time understanding image segmentation. I have implemented Unet model for image segmentation. I am using PASCAL VOC dataset and I am trying to train my model. However, I got stuck when calculating the loss. I am unsure of what should be the expected shapes of the output and target classes. Can someone please educate me on what I am doing wrong? My only guess is that I am missing something when it comes to the ground truth images since I don't know how the model will learn which class is which. Thank!

Here is my Unet class:

import torch
import torch.nn as nn
from torchvision import transforms


def x2conv(in_channels, out_channels):
    double_conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
        nn.ReLU(inplace=True))
    return double_conv


class Encoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [x2conv(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.chs = chs
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(chs[i], chs[i+1], kernel_size=2, stride=2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList(
            [x2conv(chs[i], chs[i+1]) for i in range(len(chs)-1)])

    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x = torch.cat([x, enc_ftrs], dim=1)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs = transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs


class UNet(nn.Module):
    def __init__(self, enc_chs, dec_chs, num_class):
        super(UNet, self).__init__()

        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.softmax = nn.Conv2d(dec_chs[-1], num_class, kernel_size=1)

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out = self.softmax(out)
        return out

And here is my dataset class:

from PIL import Image
import torchvision
VOC_CLASSES = [   # How to use?
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]


VOC_COLORMAP = [  # How to use?
    [0, 0, 0],  # Background
    [128, 0, 0],  # Aeroplane
    [0, 128, 0],  # Bicycle
    [128, 128, 0],  # Bird
    [0, 0, 128],  # Boat
    [128, 0, 128],  # Bottle
    [0, 128, 128],  # Bus
    [128, 128, 128],  # Car
    [64, 0, 0],  # Cat
    [192, 0, 0],  # Chair
    [64, 128, 0],  # Cow
    [192, 128, 0],  # Diningtable
    [64, 0, 128],  # Dog
    [192, 0, 128],  # Horse
    [64, 128, 128],  # Motorbike
    [192, 128, 128],  # Person
    [0, 64, 0],  # Pottedplant
    [128, 64, 0],  # Sheep
    [0, 192, 0],  # Sofa
    [128, 192, 0],  # Train
    [0, 64, 128],  # tvmonitor
]


class VocDataset(torchvision.datasets.VOCSegmentation):
    def __init__(self, image_set, transform, root="../data/VOCtrainval_11-May-2012/", download=False, year="2012"):
        self.transform = transform
        self.year = year
        super().__init__(root=root, image_set=image_set,
                         download=download, transform=transform, year=year)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # open images and do transformation  img = jpg, mask = png
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index]).convert("RGB")

        if self.transform:
            img = self.transform(img)
            trfm = T.Compose([T.ToTensor(), T.Resize((388, 388))])
            target = trfm(target)
        return img, target


and lastly here is my train function

import torch
import torch.nn as nn
import torch.optim as optim
from unet import UNet
from torch.utils.data import DataLoader
from dataset import VocDataset
import torchvision.transforms as T
import torch.nn.functional as F

# Hyperparameters etc.
STD = [0.2686, 0.2652, 0.2812]  # Std for dataset
MEAN = [0.4568, 0.4431, 0.4083]  # Mean for dataset
MOMENTUM = 0.9
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
NUM_EPOCHS = 1
NUM_WORKERS = 2
NUM_CLASSES = 20
TRAIN_SET = "train"
VAL_SET = "val"
ENC_CHANNELS = (3, 64, 128, 256, 512, 1024)  # Encoder channels
DEC_CHANNELS = (1024, 512, 256, 128, 64)  # Decoder channels
TRANSFORM = T.Compose(
    [T.ToTensor(), T.Resize(SIZE), T.Normalize(MEAN, STD)]
)


def main():
    training_data = VocDataset(TRAIN_SET, TRANSFORM)
    train_dataloader = DataLoader(
        training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)


    # Create instance of unet
    unet = UNet(ENC_CHANNELS, DEC_CHANNELS, NUM_CLASSES)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        unet.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(train_dataloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data  # Shape for labels and inputs are: [32,3,388,388] 

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = unet(inputs) # output shape is [32, 32, 388, 388] 

            loss = criterion(outputs, labels)  # Error here

            loss.backward()
            optimizer.step()


    # print('Finished Training')


if __name__ == "__main__":
    main()

like image 201
Mohammed Abdu Avatar asked Nov 18 '25 21:11

Mohammed Abdu


1 Answers

For starters, your label and outputs have different dimension. (32 vs 3 channels). Cross Entropy Loss expects them to either have the same number of channels, or for the target to have only one channel with integer values indicating the relevant class.

Let's work with the latter case. In this case, we need to reduce the target to be a single channel [32 x 388 x 388] for your input and batch size. (Secondarily, the Unet should ideally have one output channel for each class (looks like there are 22 classes so you should change the final output layer of the Unet decoder to have 22 outputs)).

To convert the label of size [32 x 3 x 388 x 388] to [32 x 388 x 388], you need to use the colormap for conversion. That is, create a new tensor target of size [32 x 1 x 388 x 388]. For each value target[i,j,k], assign the index into VOC_COLORMAP that matches the value stored in the pixels at label[i,:,j,k].

like image 188
DerekG Avatar answered Nov 22 '25 05:11

DerekG



Donate For Us

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!