
Pixel-wise loss weight for image segmentation in Keras

I am currently using a modified version of the U-Net (https://arxiv.org/pdf/1505.04597.pdf) to segment cell organelles in microscopy images. Since I am using Keras, I took the code from https://github.com/zhixuhao/unet. However, this version does not implement the weight map that forces the network to learn the border pixels.

The results I have obtained so far are quite good, but the network fails to separate objects that are close to each other, so I want to try the weight map described in the paper. I have been able to generate the weight map (based on the paper's formula) for each label image, but I could not find out how to use it to train my network and thus solve this problem.
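For reference, this is the weight map as defined in the U-Net paper, where w_c is the class-balancing weight map, d_1 is the distance to the border of the nearest cell and d_2 the distance to the border of the second-nearest cell; the paper uses w_0 = 10 and σ ≈ 5 pixels. The second line is a worked value for a background pixel lying 2 pixels from each of two neighbouring objects (an illustrative case, not from the paper).

```latex
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \cdot \exp\left( -\frac{\left(d_1(\mathbf{x}) + d_2(\mathbf{x})\right)^2}{2\sigma^2} \right)

d_1 = d_2 = 2,\; w_0 = 10,\; \sigma = 5:\quad
w_0 \exp\left(-\frac{(2+2)^2}{2 \cdot 5^2}\right) = 10\, e^{-0.32} \approx 7.26
```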

Do the weight maps and label images have to be combined somehow, or is there a Keras function that lets me use the weight maps directly? I am a biologist who only recently started working with neural networks, so my understanding is still limited. Any help or advice would be greatly appreciated.

asked May 09 '18 at 14:05 by disputator1991




1 Answer

In case it is still relevant: I needed to solve this recently. You can paste the code below into a Jupyter notebook to see how it works.

%matplotlib inline
import numpy as np
from skimage.io import imshow
from skimage.measure import label
# scipy.ndimage.morphology is deprecated; import from scipy.ndimage directly
from scipy.ndimage import distance_transform_edt

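def generate_random_circles(n=100, d=256):
    # n random circles as (row, col, radius) triples with values in [0, d)
    circles = np.random.randint(0, d, (n, 3))
    x = np.zeros((d, d), dtype=int)
    rows, cols = np.ogrid[:d, :d]
    for x0, y0, r in circles:
        # add a filled disc of radius r/d*10 centred at (x0, y0)
        x += ((rows - x0)**2 + (cols - y0)**2) <= (r / d * 10)**2
    # overlapping circles sum above 1, so clip back to a binary mask
    x = np.clip(x, 0, 1)

    return x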

def unet_weight_map(y, wc=None, w0=10, sigma=5):
    """
    Generate a weight map as specified in the U-Net paper
    for a binary mask.

    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    https://arxiv.org/pdf/1505.04597.pdf

    Parameters
    ----------
    y: Numpy array
        2D array of shape (image_height, image_width) representing a binary
        mask of objects (0 = background, 1 = object).
    wc: dict
        Dictionary mapping class values in y to class weights.
    w0: float
        Border weight parameter.
    sigma: float
        Border width parameter (in pixels).

    Returns
    -------
    Numpy array
        Training weights. A 2D array of shape (image_height, image_width).
    """

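    labels = label(y)          # connected components: 0 = background, 1..N = objects
    no_labels = labels == 0    # background mask; the border term is applied only here
    label_ids = np.unique(labels[labels > 0])  # object ids, excluding background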

    if len(label_ids) > 1:
        distances = np.zeros((y.shape[0], y.shape[1], len(label_ids)))

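        for i, label_id in enumerate(label_ids):
            # Euclidean distance from every pixel to the nearest pixel of object label_id
            distances[:,:,i] = distance_transform_edt(labels != label_id)

        # sort along the object axis: d1 = nearest object, d2 = second-nearest object
        distances = np.sort(distances, axis=2)
        d1 = distances[:,:,0]
        d2 = distances[:,:,1]
        # w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)), kept only on background pixels
        w = w0 * np.exp(-1/2*((d1 + d2) / sigma)**2) * no_labels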
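    else:
        # fewer than two objects: there are no gaps between objects to emphasise
        w = np.zeros_like(y, dtype=float)
    if wc:
        # add the class-balancing weights; float dtype so non-integer weights survive
        class_weights = np.zeros_like(y, dtype=float)
        for k, v in wc.items():
            class_weights[y == k] = v
        w = w + class_weights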
    return w

y = generate_random_circles()

wc = {
    0: 1, # background
    1: 5  # objects
}

w = unet_weight_map(y, wc)

imshow(w)
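The code above only builds the map. For the training part of the question: in the paper the map weights the pixel-wise cross-entropy, i.e. each pixel's loss term is multiplied by w(x) before the terms are added up, so the weight map and the label image are combined inside the loss. Below is a minimal NumPy sketch of that idea, assuming a binary setup with a single sigmoid output channel (predicted foreground probability per pixel) rather than the paper's multi-class softmax, and averaging instead of summing (which only rescales the loss). `weighted_pixel_bce` and `split_stacked_target` are hypothetical helpers, not Keras functions. Since a Keras loss function only receives `y_true` and `y_pred`, one common workaround is to stack the weight map onto the label as an extra channel of `y_true` and split it inside a custom loss written with TensorFlow ops; the sketch mimics that convention with channel 0 = label and channel 1 = weight.

```python
import numpy as np

def weighted_pixel_bce(y_true, y_pred, weights, eps=1e-7):
    """Pixel-wise weighted binary cross-entropy (hypothetical helper).

    y_true:  binary labels, e.g. shape (H, W) or (N, H, W)
    y_pred:  predicted foreground probabilities, same shape
    weights: weight map w(x), same shape
    """
    # clip probabilities so log() never sees 0 or 1
    y_pred = np.clip(y_pred, eps, 1 - eps)
    # ordinary per-pixel binary cross-entropy
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    # scale every pixel's loss by its weight, then average over all pixels
    return np.mean(weights * bce)

def split_stacked_target(y_stacked):
    """Assumed convention: last axis holds [label, weight map]."""
    return y_stacked[..., 0], y_stacked[..., 1]

# toy 2x2 example: labels, predictions and a weight map stacked like y_true
y = np.array([[0.0, 1.0], [1.0, 0.0]])
p = np.array([[0.1, 0.9], [0.8, 0.3]])
w = np.ones_like(y)
w[1, 1] = 10.0  # pretend this pixel lies in a gap between two objects

labels, weights = split_stacked_target(np.stack([y, w], axis=-1))
loss = weighted_pixel_bce(labels, p, weights)
```

With all weights equal to 1 the result reduces to the ordinary mean binary cross-entropy; raising the weight of the gap pixels between touching objects makes mistakes there cost more, which is what pushes the network to separate them.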
answered Sep 21 '22 at 10:09 by Rok
