
Generalized dice loss for multi-class segmentation: keras implementation

I just implemented the generalised dice loss (a multi-class version of dice loss) in Keras, as described in ref:

(my targets are defined as: (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes))

def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef

But something must be wrong. I'm working with 3D images that I have to segment into 4 classes (1 background class and 3 object classes; the dataset is imbalanced). First odd thing: while my training loss and accuracy improve during training (and converge really fast), my validation loss and accuracy stay constant through the epochs (see image). Second, when predicting on test data, only the background class is predicted: I get a constant volume.

I used the exact same data and script with categorical cross-entropy loss and got plausible results (the object classes are segmented), which means something is wrong with my implementation. Any idea what it could be?

Plus, I believe it would be useful to the Keras community to have a generalised dice loss implementation, as it seems to be used in most recent semantic segmentation tasks (at least in the medical imaging community).

PS: the way the weights are defined seems odd to me; I get values around 10^-10. Has anyone else tried to implement this? I also tested my function without the weights but got the same problems.

Manu, asked Feb 27 '18 15:02




1 Answer

I think the problem here is your weights. Imagine you are trying to solve a multi-class segmentation problem in which only a few of the classes are present in any given image. A toy example of this (and the one that led me to this problem) is to create a segmentation dataset from MNIST in the following way.

x is a 28x28 image and y is a 28x28x11 label map (10 digit classes plus background). Each pixel is labelled as background if its normalised grayscale value is below 0.4, and otherwise as the digit class of the original image x. So for a picture of the number one, you get a bunch of pixels labelled as "one" and the rest labelled as background.

In this dataset only two classes are ever present in an image. This means that, following your dice loss, 9 of the weights will be 1./(0. + eps) = large, so for every image we strongly penalise any probability assigned to the 9 absent classes. An evidently strong local minimum for the network in this situation is to predict everything as the background class.
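To put rough numbers on this, here is a small pure-Python sketch. The pixel counts (700 background, 84 digit), the single absent class, and the 1% "leak" are made-up values for illustration, and `gdl` is a hypothetical helper, not code from the question. It evaluates the loss from per-class sums under the question's weights, then with a small fixed weight for the absent class.

```python
def gdl(weights, intersections, totals):
    """Generalized dice loss from per-class sums (illustrative helper)."""
    num = sum(w * i for w, i in zip(weights, intersections))
    den = sum(w * t for w, t in zip(weights, totals))
    return 1.0 - 2.0 * num / den

# Made-up label counts for one 28x28 image: background, digit, one absent class.
true_counts = [700.0, 84.0, 0.0]

# Weights as in the question: 1 / (count^2 + 0.00001).
w_question = [1.0 / (c ** 2 + 1e-5) for c in true_counts]
# Alternative (assumption): a small fixed weight when the class is absent.
w_fixed = [1.0 / c ** 2 if c > 0 else 1e-6 for c in true_counts]

# Prediction A: every pixel predicted as background with probability 1.
inter_bg = [700.0, 0.0, 0.0]   # per-class sum of y_true * y_pred
pred_bg = [784.0, 0.0, 0.0]    # per-class sum of y_pred

# Prediction B: correct everywhere, but 1% of the mass leaks to the absent class.
inter_leak = [693.0, 83.16, 0.0]
pred_leak = [693.0, 83.16, 7.84]

def totals(pred):
    # Per-class sum of y_true + y_pred.
    return [t + p for t, p in zip(true_counts, pred)]

loss_bg_question = gdl(w_question, inter_bg, totals(pred_bg))
loss_leak_question = gdl(w_question, inter_leak, totals(pred_leak))
loss_leak_fixed = gdl(w_fixed, inter_leak, totals(pred_leak))

print(w_question)
print(loss_bg_question, loss_leak_question, loss_leak_fixed)
```

With the question's weights, the almost-correct prediction B scores close to 1.0, which is worse than the all-background prediction A at roughly 0.81, so collapsing to background is the cheaper option. With the small fixed weight, prediction B drops to roughly 0.005.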

We do want to penalise any incorrectly predicted classes which are not in the image, but not so strongly. So we just need to modify the weights. This is how I did it:

import tensorflow as tf

def gen_dice(y_true, y_pred, eps=1e-6):
    """both tensors are [b, h, w, classes] and y_pred is in logit form"""

    # [b, h, w, classes]
    pred_tensor = tf.nn.softmax(y_pred)
    y_true_shape = tf.shape(y_true)

    # [b, h*w, classes]
    y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
    y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])

    # [b, classes]
    # count how many of each class are present in 
    # each image, if there are zero, then assign
    # them a fixed weight of eps
    counts = tf.reduce_sum(y_true, axis=1)
    weights = 1. / (counts ** 2)
    weights = tf.where(tf.math.is_finite(weights), weights, eps)

    multed = tf.reduce_sum(y_true * y_pred, axis=1)
    summed = tf.reduce_sum(y_true + y_pred, axis=1)

    # [b]
    numerators = tf.reduce_sum(weights*multed, axis=-1)
    denom = tf.reduce_sum(weights*summed, axis=-1)
    dices = 1. - 2. * numerators / denom
    dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
    return tf.reduce_mean(dices)
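The function above expects 2D images, while the question uses 3D volumes with targets of shape (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes). Below is a minimal sketch of how the same idea might be extended to any number of spatial axes. The name `gen_dice_nd` and the toy volume are my own assumptions. I also swapped the `is_finite` mask on the weights for `tf.math.divide_no_nan`, which avoids creating infinite intermediates, and dropped the final `is_finite` guard because the denominator is strictly positive for softmax outputs. As in the function above, y_pred is assumed to hold logits.

```python
import tensorflow as tf


def gen_dice_nd(y_true, y_pred, eps=1e-6):
    """Hypothetical N-D variant: tensors are [b, ..., classes], y_pred holds logits."""
    shape = tf.shape(y_true)
    flat = [shape[0], -1, shape[-1]]  # [b, voxels, classes]
    y_true = tf.reshape(tf.cast(y_true, tf.float32), flat)
    probs = tf.reshape(tf.nn.softmax(y_pred, axis=-1), flat)

    # [b, classes]: voxels per class in each sample.
    counts = tf.reduce_sum(y_true, axis=1)
    # Inverse squared volume for present classes, fixed eps for absent ones.
    inv_sq = tf.math.divide_no_nan(tf.ones_like(counts), counts ** 2)
    weights = tf.where(counts > 0, inv_sq, eps)

    intersect = tf.reduce_sum(y_true * probs, axis=1)
    total = tf.reduce_sum(y_true + probs, axis=1)

    # [b]: one weighted dice loss per sample, then the batch mean.
    num = tf.reduce_sum(weights * intersect, axis=-1)
    den = tf.reduce_sum(weights * total, axis=-1)
    return tf.reduce_mean(1.0 - 2.0 * num / den)


# Toy 4x4x4 volumes with 4 classes; half the voxels are class 0, half class 1.
labels = tf.concat(
    [tf.zeros([2, 2, 4, 4], tf.int32), tf.ones([2, 2, 4, 4], tf.int32)], axis=1
)
y_true = tf.one_hot(labels, depth=4)  # [2, 4, 4, 4, 4]

perfect_logits = 20.0 * y_true
background_logits = 20.0 * tf.one_hot(tf.zeros_like(labels), depth=4)

loss_perfect = float(gen_dice_nd(y_true, perfect_logits))
loss_background = float(gen_dice_nd(y_true, background_logits))
print(loss_perfect, loss_background)
```

On this toy volume the perfect prediction gives a loss of about 0 and the all-background prediction about 0.5. To use it as a Keras loss, the last layer of the model would have to output raw logits rather than softmax probabilities.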
Ben, answered Jan 04 '23 05:01