 

Having issues with neural network training. Loss not decreasing

I'm largely following this project, but am doing pixel-wise classification. I have 8 classes and 9-band imagery. My images are gridded into 9x128x128 tiles. My loss is not decreasing, and training accuracy doesn't fluctuate much. I'm guessing I have something wrong with the model. Any advice is much appreciated! I get at least 91% accuracy using random forest.

My classes are extremely unbalanced so I attempted to adjust training weights based on the proportion of classes within the training data.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import unet  # local module containing UNetSmall (shown below)

# get model
learning_rate = 0.0001
model = unet.UNetSmall(8)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# set up weights based on data proportion
weights = np.array([0.79594768, 0.07181202, 0.02347426, 0.0042031,
                    0.00366211, 0.00764327, 0.07003923, 0.02321833])
weights = (1 - weights) / 7  # complement of each class proportion; the 8 weights sum to 1
print('Weights of training data based on proportion of the training labels. Not computed here')
print(weights)
print(sum(weights))
# CrossEntropyLoss expects the weights as a float tensor; note the variable
# is `weights` (a bare `weight` would be an undefined name here)
criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(weights).float())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

Weights of training data based on proportion of the training labels. Not computed here
[0.02915033 0.13259828 0.13950368 0.1422567  0.14233398 0.14176525 0.13285154 0.13954024]
1.0000000000000002
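For reference, here is a sketch of an alternative weighting scheme, inverse class frequency normalized so the mean weight is 1; this is illustrative only (using the imports above), not what I used:

# Illustrative alternative: inverse-frequency class weights.
# Rare classes get proportionally larger weights.
inv_freq = 1.0 / np.array([0.79594768, 0.07181202, 0.02347426, 0.0042031,
                           0.00366211, 0.00764327, 0.07003923, 0.02321833])
inv_freq = inv_freq / inv_freq.sum() * 8  # normalize so the mean weight is 1
alt_criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(inv_freq).float())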

I've normalized the data using the transforms.functional.normalize function. I calculated the mean and standard deviation of the training data and added this transform to my data loader.

dataset_train = data_utils.SatIn(
    data_path, 'TrainValTest.csv', 'train',
    transform=transforms.Compose([aug.ToTensorTarget(),
                                  aug.NormalizeTarget(mean=popmean, std=popstd)]))
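For context, popmean and popstd are the per-band mean and standard deviation of the training set. A minimal sketch of how they can be computed, assuming the images load as (9, 128, 128) float arrays (the helper and the training_images list are hypothetical):

import numpy as np

def band_stats(image_arrays):
    """Per-band mean and std across a list of (C, H, W) arrays."""
    stacked = np.stack(image_arrays)     # (N, C, H, W)
    mean = stacked.mean(axis=(0, 2, 3))  # one value per band
    std = stacked.std(axis=(0, 2, 3))
    return mean, std

popmean, popstd = band_stats(training_images)  # training_images: list of (9, 128, 128) arrays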

I augmented my training data in preprocessing by rotating and flipping the imagery, so each image grid became eight.
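A sketch of that 8-fold augmentation, 4 rotations times an optional horizontal flip, assuming (C, H, W) numpy arrays (the helper is illustrative):

import numpy as np

def dihedral_variants(img):
    """Yield the 8 rotation/flip variants of a (C, H, W) array."""
    for k in range(4):
        rotated = np.rot90(img, k, axes=(1, 2))  # rotate in the H, W plane
        yield rotated
        yield rotated[:, :, ::-1]                # plus a horizontal flip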

I checked that my training data matched my classes, and everything checked out. Since I'm using 8 classes, I chose CrossEntropyLoss, which has the softmax built in.
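For reference, CrossEntropyLoss expects raw logits of shape (N, C, H, W) and integer class labels of shape (N, H, W); a standalone shape check (illustrative):

import torch
import torch.nn as nn

logits = torch.randn(4, 8, 128, 128)          # (N, C, H, W), no softmax applied
targets = torch.randint(0, 8, (4, 128, 128))  # (N, H, W), class indices 0..7
print(nn.CrossEntropyLoss()(logits, targets).item())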

Current model

class UNetSmall(nn.Module):
    """
    Main UNet architecture
    """
    def __init__(self, num_classes=1):
        super().__init__()
        # encoding
        self.conv1 = encoding_block(9, 32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = encoding_block(32, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = encoding_block(64, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = encoding_block(128, 256)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # center
        self.center = encoding_block(256, 512)

        # decoding
        self.decode4 = decoding_block(512, 256)
        self.decode3 = decoding_block(256, 128)
        self.decode2 = decoding_block(128, 64)
        self.decode1 = decoding_block(64, 32)

        # final
        self.final = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, input):

        # encoding
        conv1 = self.conv1(input)
        maxpool1 = self.maxpool1(conv1)
        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)
        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)
        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        # center
        center = self.center(maxpool4)

        # decoding
        decode4 = self.decode4(conv4, center)
        decode3 = self.decode3(conv3, decode4)
        decode2 = self.decode2(conv2, decode3)
        decode1 = self.decode1(conv1, decode2)
        # final
        # nn.functional.upsample is deprecated; interpolate is the current API
        final = nn.functional.interpolate(self.final(decode1), size=input.size()[2:],
                                          mode='bilinear', align_corners=False)
        return final
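A quick smoke test for the network above (illustrative; it assumes encoding_block and decoding_block from the referenced project are importable):

import torch

model = UNetSmall(num_classes=8)
x = torch.randn(2, 9, 128, 128)  # batch of 2, 9 bands, 128x128 pixels
print(model(x).shape)            # expected: torch.Size([2, 8, 128, 128])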

Training method

def train(train_loader, model, criterion, optimizer, epoch_num):

    correct = 0
    totalcount = 0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion.to(device)  # also moves the class-weight tensor to the GPU
    model.train()

    # iterate over data; Variable is no longer needed in current PyTorch
    for idx, data in enumerate(tqdm(train_loader, desc="training")):
        inputs = data['sat_img'].to(device)
        labels = data['map_img'].to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # pixel-wise accuracy over the batch
        preds = torch.max(outputs.data, 1)[1]
        correct += (preds == labels).sum().item()
        totalcount += labels.numel()

    # note: the printed loss is from the last batch of the epoch only
    print('Training Loss: {:.4f}, Accuracy: {:.2f}'.format(loss.item(), correct / totalcount))
    return {'train_loss': loss.item(), 'train_acc': correct / totalcount}

Training call in epoch loop

train_metrics = train(train_dataloader, model, criterion, optimizer, epoch)
lr_scheduler.step()  # step the scheduler exactly once per epoch; stepping it
                     # both here and inside train() would decay the LR twice as fast

Some epoch iteration output

#### Epoch 0/19

----------
training: 100%|████████████████████████████████████████████████████████████████████████| 84/84 [00:17<00:00, 5.77it/s]
Training Loss: 0.8901, Accuracy: 0.83
Current elapsed time 2m 6s

#### Epoch 1/19

----------
training: 100%|████████████████████████████████████████████████████████████████████████| 84/84 [00:17<00:00, 5.72it/s]
Training Loss: 0.7922, Accuracy: 0.83
Current elapsed time 2m 24s

#### Epoch 2/19

----------
training: 100%|████████████████████████████████████████████████████████████████████████| 84/84 [00:18<00:00, 5.44it/s]
Training Loss: 0.8753, Accuracy: 0.84
Current elapsed time 2m 42s

#### Epoch 3/19

----------
training: 100%|████████████████████████████████████████████████████████████████████████| 84/84 [00:18<00:00, 5.53it/s]
Training Loss: 0.7741, Accuracy: 0.84
Current elapsed time 3m 1s

asked Dec 24 '22 by mkmitchell

2 Answers

It's hard to debug your model with this information, but maybe some of these ideas will help you in some way:

  1. Try to overfit your network on much smaller data first, without augmentation, say one or two batches for many epochs (see the first sketch after this list). If that doesn't work, your model is not capable of modelling the relation between the data and the desired target, or you have an error somewhere. Furthermore, it's easier to debug it that way.
  2. I'm not sure about the weighting idea; maybe try oversampling underrepresented classes in order to make the data more balanced (repeat some underrepresented examples in your dataset; see the second sketch after this list). I'm curious where this idea is from, as I've never heard of it.
  3. Have you tried to run the model from the repo you provided before applying your own customisations? How well does it perform, and were you able to replicate their findings? Why do you think this architecture would be a good fit for your, from what I understand, different case? The loss function in the link you provided is different, while the architecture is the same. I haven't read the paper, nor have I tried your model, but it seems a little strange.
  4. The link inside the GitHub repo points to a blog post where bigger batches are advised, as they stabilize training; what is your batch size?
  5. Maybe start with a smaller and easier model and work your way up from there?
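A minimal sketch of point 1, overfitting a single batch; the variable names are taken from your code, and the step count is arbitrary:

# Overfit one batch for many steps; if the loss won't approach zero even
# here, the model or the loss wiring is the likely culprit.
batch = next(iter(train_dataloader))
inputs, targets = batch['sat_img'], batch['map_img'].long()

model.train()
for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(step, loss.item())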
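And a sketch of point 2, oversampling rare examples with a weighted sampler; sample_weights (one weight per training patch, e.g. based on the rarest class each patch contains) is an assumption you would need to compute yourself:

# Draw patches with probability proportional to sample_weights, so patches
# containing rare classes appear more often in each epoch.
from torch.utils.data import DataLoader, WeightedRandomSampler

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(dataset_train),
                                replacement=True)
train_dataloader = DataLoader(dataset_train, batch_size=16, sampler=sampler)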

And the most important point comes last: I don't think SO is the best place for such a question (especially as it is research oriented). I see you have already asked it on GitHub issues, though; maybe try to contact the author directly?

If I were you, I would start with the last point and a thorough understanding of the operations and their effect on your goal. Good luck.

answered Apr 28 '23 by Szymon Maszke

The model did not suit my purpose, and I don't know enough about these models to know why. I switched to a different UNet model found here, and everything started working.

Accuracy is on par with what random forest produces. When I attempted to remove the weighting, I was getting NaN as the loss. With the new approach, the loss reduces down to ~0.2 instead of hovering above 0.5. Training accuracy quickly increased to the high 80s in the first 50 epochs and didn't go above that in the next 50.

I plan on testing a few different models, similar to what the authors did in this paper. I'll create a simple baseline and compare the results to UNet and VGG16.

answered Apr 28 '23 by mkmitchell