
Keras - custom loss function - chamfer distance

I am attempting object segmentation using a custom loss function as defined below:

from keras import backend as K

def chamfer_loss_value(y_true, y_pred):

    # flatten the batch
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    # get the chamfer distance sum

    # error here
    y_pred_mask_f = K.cast(K.greater_equal(y_pred_f, 0.5), dtype='float32')

    finalChamferDistanceSum = K.sum(y_pred_mask_f * y_true_f, axis=1, keepdims=True)

    return K.mean(finalChamferDistanceSum)

def chamfer_loss(y_true, y_pred):   
    return chamfer_loss_value(y_true, y_pred)

y_pred_f is the output of my U-net. y_true_f is the result of a Euclidean distance transform on the ground truth label mask x, computed as shown below:

distTrans = ndimage.distance_transform_edt(1 - x)

To compute the Chamfer distance, you multiply the predicted image (ideally a binary mask of 1s and 0s) by the ground truth distance transform, and simply sum over all pixels. To do this, I need to get a mask y_pred_mask_f by thresholding y_pred_f, then multiply it with y_true_f and sum over all pixels.
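For intuition, here is a minimal NumPy sketch of that computation outside of Keras, using toy masks (the names and shapes are illustrative only):

import numpy as np
from scipy import ndimage

x = np.zeros((8, 8)); x[2:5, 2:5] = 1                  # toy ground truth mask
pred_mask = np.zeros((8, 8)); pred_mask[3:6, 3:6] = 1  # toy thresholded prediction

dist = ndimage.distance_transform_edt(1 - x)  # distance of each pixel to the object
chamfer = np.sum(pred_mask * dist)            # only misplaced predicted pixels add cost
print(chamfer)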

y_pred_f contains a continuous range of values in [0,1], and I get the error None type not supported when y_pred_mask_f is evaluated. I know the loss function has to be differentiable, and greater_equal and cast are not. But is there a way to circumvent this in Keras, perhaps using some workaround in TensorFlow?

asked Jan 29 '23 by Eagle

1 Answer

Well, this was tricky. The reason behind your error is that there is no continuous dependence between your loss and your network. To compute the gradients of the loss w.r.t. the network's weights, the loss would have to differentiate through the indicator of whether the output is greater than 0.5 (as this thresholding is the only connection between the final loss value and the network's output y_pred). That is impossible, because the indicator is piecewise constant rather than continuous: its gradient is zero almost everywhere and undefined at the threshold.
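You can see this directly in TensorFlow (a minimal sketch using TF 2's GradientTape for illustration, not part of the original setup): the hard threshold reports no gradient at all.

import tensorflow as tf

x = tf.Variable([0.3, 0.7])
with tf.GradientTape() as tape:
    mask = tf.cast(x >= 0.5, tf.float32)  # hard threshold, like the cast above
    loss = tf.reduce_sum(mask)
print(tape.gradient(loss, x))  # None: the cast severs the gradient path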

Possible solution - smooth your indicator:

from keras import backend as K

def chamfer_loss_value(y_true, y_pred):

    # flatten the batch
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    # smooth, differentiable stand-in for the hard >= 0.5 threshold
    y_pred_mask_f = K.sigmoid(y_pred_f - 0.5)

    finalChamferDistanceSum = K.sum(y_pred_mask_f * y_true_f, axis=1, keepdims=True)

    return K.mean(finalChamferDistanceSum)

This works because the sigmoid is a continuous approximation of a step function. If your output already comes from a sigmoid, you could simply use y_pred_f instead of y_pred_mask_f.
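For completeness, a sketch of how the smoothed loss might be wired into training; model, images, and dist_transforms are hypothetical names standing in for the U-net and its data:

# `model` is the U-net; the targets are the EDT maps, not the raw masks
model.compile(optimizer='adam', loss=chamfer_loss)
model.fit(images, dist_transforms, batch_size=8, epochs=10)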

answered Jan 31 '23 by Marcin Możejko