
How to implement sparse mean squared error loss in Keras

I want to modify the following Keras mean squared error (MSE) loss so that the loss is computed only sparsely.

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

My output y is a 3-channel image, where the 3rd channel is non-zero only at those pixels where the loss is to be computed. Any idea how I can modify the above to compute a sparse loss?

asked Mar 09 '23 by delhi_loafer

1 Answer

This is not the exact loss you are looking for, but I hope it gives you a hint for writing your own function (see also the related GitHub discussion):

def masked_mse(mask_value):
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        masked_mse = (K.sum(masked_squared_error, axis=-1) /
                      K.sum(mask_true, axis=-1))
        return masked_mse
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
    return f

The function computes the MSE loss over all the values of the predicted output, except for those elements whose corresponding value in the true output is equal to a masking value (e.g. -1).

Two notes:

  • when computing the mean, the denominator must be the count of non-masked values, not the dimension of the array; that's why I average manually instead of using K.mean(masked_squared_error, axis=-1).
  • the masking value must be a valid number (e.g. np.nan or np.inf will not do the job), which means that you'll have to adapt your data so that it does not contain the mask_value.
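For your specific setup, where the 3rd channel of y_true flags the pixels that count, the same idea can be applied per channel rather than per value. The sketch below is my own adaptation, not part of the original answer: it assumes the loss applies to channels 0 and 1 and that channel 2 is only a binary flag (all names are illustrative), and it uses plain TensorFlow ops equivalent to the K.* calls above.

```python
import tensorflow as tf

def channel_masked_mse(y_true, y_pred):
    # Channel 2 of y_true flags the pixels that count toward the loss.
    mask = tf.cast(tf.not_equal(y_true[..., 2], 0.0), tf.float32)    # (batch, H, W)
    # Per-pixel squared error over the two data channels.
    squared_error = (tf.square(y_true[..., 0] - y_pred[..., 0]) +
                     tf.square(y_true[..., 1] - y_pred[..., 1]))
    # Divide by the number of flagged values (2 channels per flagged pixel);
    # the small epsilon avoids 0/0 when no pixel is flagged.
    return tf.reduce_sum(mask * squared_error) / (2.0 * tf.reduce_sum(mask) + 1e-7)
```

For a batch with one flagged pixel whose error is (1, 0) and one unflagged pixel, only the flagged pixel contributes, giving (1 + 0) / 2 = 0.5.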

In this example, the target output is always [1, 1, 1, 1], but progressively more of the target values are masked.

y_pred = K.constant([[ 1, 1, 1, 1], 
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 1],
                     [-1, 1, 1, 1],
                     [-1,-1, 1, 1],
                     [-1,-1,-1, 1],
                     [-1,-1,-1,-1]])

true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))

for i in range(true.shape[0]):
    print(true[i], pred[i], loss[i], sep='\t')

The expected output is:

[ 1.  1.  1.  1.]  [ 1.  1.  1.  1.]  0.0
[ 1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.0
[-1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.33333
[-1. -1.  1.  1.]  [ 1.  1.  1.  3.]  2.0
[-1. -1. -1.  1.]  [ 1.  1.  1.  3.]  4.0
[-1. -1. -1. -1.]  [ 1.  1.  1.  3.]  nan
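The last row comes out as nan because every target equals the mask value, so K.sum(mask_true) is zero and the division is 0/0. To actually train with such a loss, pass it to compile like any custom loss. The sketch below is a minimal illustration (the Dense model and data are toy placeholders, and the loss is rewritten with plain TensorFlow ops):

```python
import numpy as np
import tensorflow as tf

def masked_mse(mask_value):
    def f(y_true, y_pred):
        # Per-sample masked MSE, same logic as the answer above.
        mask = tf.cast(tf.not_equal(y_true, mask_value), tf.float32)
        masked_squared_error = tf.square(mask * (y_true - y_pred))
        return tf.reduce_sum(masked_squared_error, axis=-1) / tf.reduce_sum(mask, axis=-1)
    return f

# Toy model: 8 inputs -> 4 outputs (purely illustrative).
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.compile(optimizer='adam', loss=masked_mse(-1.0))

# Targets use -1 for "don't care" entries; here only output 0 is supervised.
x = np.random.rand(16, 8).astype('float32')
y = np.full((16, 4), -1.0, dtype='float32')
y[:, 0] = 1.0
loss = model.train_on_batch(x, y)
```

Note that each sample must keep at least one non-masked target, otherwise that sample's loss is nan, as in the last row above.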
answered Mar 23 '23 by baldassarreFe