
Tensor indexing in custom loss function

Basically, I want my custom loss function to choose between the usual MSE and a custom MSE that subtracts values at different indexes.

To clarify, let's say I have a y_pred tensor that is [1, 2, 4, 5] and a y_true tensor that is [2, 5, 1, 3]. In usual MSE, we should get:

return K.mean(K.square(y_pred - y_true))

That would do the following:

[1, 2, 4, 5] - [2, 5, 1, 3] = [-1, -3, 3, 2]

[-1, -3, 3, 2]² = [1, 9, 9, 4]

mean([1, 9, 9, 4]) = 5.75

I need my custom loss function to select the minimum between this mean and a second one, computed after swapping indexes 1 and 3 of the y_pred tensor, i.e.:

[1, 5, 4, 2] - [2, 5, 1, 3] = [-1, 0, 3, -1]

[-1, 0, 3, -1]² = [1, 0, 9, 1]

mean([1, 0, 9, 1]) = 2.75

So, my custom loss would return 2.75, which is the minimum of the two means. To do this, I tried to convert the y_true and y_pred tensors to numpy arrays and do all the related math there, as follows:

def new_mse(y_true, y_pred):
    sess = tf.Session()

    with sess.as_default():
        np_y_true = y_true.eval()
        np_y_pred = y_pred.eval()

        np_err_mse = np.empty(np_y_true.shape)
        np_err_mse = np.square(np_y_pred - np_y_true)

        np_err_new_mse = np.empty(np_y_true.shape)
        l0 = np.square(np_y_pred[:, 2] - np_y_true[:, 0])   
        l1 = np.square(np_y_pred[:, 3] - np_y_true[:, 1])
        l2 = np.square(np_y_pred[:, 0] - np_y_true[:, 2])
        l3 = np.square(np_y_pred[:, 1] - np_y_true[:, 3])   
        l4 = np.square(np_y_pred[:, 4] - np_y_true[:, 4])
        l5 = np.square(np_y_pred[:, 5] - np_y_true[:, 5])
        np_err_new_mse = np.transpose(np.vstack([l0, l1, l2, l3, l4, l5]))

        np_err_mse = np.mean(np_err_mse)
        np_err_new_mse = np.mean(np_err_new_mse)

        return np.amin([np_err_mse, np_err_new_mse])

The problem is that I can't use the eval() method on the y_true and y_pred tensors, and I'm not sure why. Finally, my questions are:

  1. Is there an easier way to work with indexing inside tensors and loss functions? I'm a newbie to Tensorflow and Keras in general, and I strongly believe converting everything to numpy arrays isn't the optimal method at all.
  2. Not totally related to the question, but when I tried to print the y_true tensor's shape with K.shape(y_true), I got "Tensor("Shape_1:0", shape=(2,), dtype=int32)". That confuses me, since I'm working with a y.shape equal to (7032, 6), that is, 7032 images with 6 labels each. I am probably misinterpreting how my y and y_pred are used by the loss function.
Mateus Hora asked Sep 13 '17 14:09


2 Answers

In a loss function you usually work only with backend functions, and you never try to read the actual values of the tensors.

from keras import backend as K
from keras.losses import mean_squared_error

def new_mse(y_true, y_pred):

    # swap elements 1 and 3 by concatenating slices of the original tensor
    # (this form is for a 1D tensor with 4 elements)
    swapped = K.concatenate([y_pred[:1], y_pred[3:], y_pred[2:3], y_pred[1:2]])
    # if the tensors are shaped like (batchSize, 4), use this instead:
    # swapped = K.concatenate([y_pred[:, :1], y_pred[:, 3:], y_pred[:, 2:3], y_pred[:, 1:2]])

    # losses, each reduced to a single scalar
    regularLoss = K.mean(mean_squared_error(y_true, y_pred))
    swappedLoss = K.mean(mean_squared_error(y_true, swapped))

    # take the minimum of the two scalars
    return K.minimum(regularLoss, swappedLoss)

So, for your items:

  1. You're totally right. Avoid numpy at all costs in tensor operations (loss functions, activations, custom layers, etc.)

  2. The result of K.shape() is also a tensor. It has shape (2,) because it holds two values: the number of samples and the number of labels (6). Note that a loss function receives one batch at a time, so the first value is the batch size rather than 7032. You can only see these values when you eval this tensor, and doing that inside a loss function is often a bad idea.
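To check a backend implementation against the numbers in the question, the expected values can be computed by hand. Here is a minimal plain-Python sketch of that check; it does not use Keras at all, and the helper name mse is made up for this illustration:

```python
# Plain-Python check of the worked example from the question (no Keras involved).
# The helper name `mse` is hypothetical, used only for this illustration.

def mse(y_true, y_pred):
    """Mean of the squared element-wise differences of two equal-length lists."""
    return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

y_true = [2, 5, 1, 3]
y_pred = [1, 2, 4, 5]

# Swap the elements at indexes 1 and 3 of y_pred.
swapped = [y_pred[0], y_pred[3], y_pred[2], y_pred[1]]

regular_loss = mse(y_true, y_pred)    # 5.75
swapped_loss = mse(y_true, swapped)   # 2.75
result = min(regular_loss, swapped_loss)

print(regular_loss, swapped_loss, result)
```

A Keras loss evaluated on the same two vectors should agree with result.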

Daniel Möller answered Oct 22 '22 04:10


If you are using Keras 2, you can use the K.gather function to perform the indexing. Keep in mind that it indexes along the first axis of the tensor.

Daniel Möller's answer becomes:

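from keras import backend as K
from keras.losses import mean_squared_error

def reindex(t, perm):
    # K.gather indexes along the first axis, so this reorders a 1D tensor;
    # for tensors shaped like (batchSize, n), transpose before and after:
    # return K.transpose(K.gather(K.transpose(t), perm))
    return K.gather(t, perm)

def my_loss(y_true, y_pred):

    # losses, each reduced to a single scalar
    regularLoss = K.mean(mean_squared_error(y_true, y_pred))
    swappedLoss = K.mean(mean_squared_error(y_true, reindex(y_pred, [0, 3, 2, 1])))

    # take the minimum of the two scalars
    return K.minimum(regularLoss, swappedLoss)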
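For the batched, 6-label case from the question, the same idea can be written as a plain-Python reference to compare a Keras loss against. This is only a sketch: permute_columns, batch_mse and min_mse are hypothetical helper names, the sample batch is made up, and the permutation [2, 3, 0, 1, 4, 5] is read off the l0..l5 lines in the question's code.

```python
# Plain-Python reference for the batch case, using lists of rows instead of tensors.
# `permute_columns`, `batch_mse` and `min_mse` are hypothetical names, not Keras functions.

def permute_columns(rows, perm):
    """Reorder every row so that output column j holds input column perm[j]."""
    return [[row[i] for i in perm] for row in rows]

def batch_mse(y_true, y_pred):
    """Mean squared error over every element of the batch."""
    total = sum((p - t) ** 2
                for t_row, p_row in zip(y_true, y_pred)
                for t, p in zip(t_row, p_row))
    count = sum(len(row) for row in y_true)
    return total / count

def min_mse(y_true, y_pred, perm):
    """Minimum of the regular MSE and the MSE with y_pred's columns permuted."""
    return min(batch_mse(y_true, y_pred),
               batch_mse(y_true, permute_columns(y_pred, perm)))

# Permutation from the question's code: prediction columns 0<->2 and 1<->3 are swapped.
perm = [2, 3, 0, 1, 4, 5]

# Made-up batch of two samples with six labels each.
y_true = [[1, 2, 3, 4, 5, 6],
          [0, 0, 1, 1, 2, 2]]
y_pred = [[3, 4, 1, 2, 5, 6],
          [0, 0, 1, 1, 2, 2]]

loss = min_mse(y_true, y_pred, perm)
print(loss)
```

Note that this takes the minimum of the two means over the whole batch, as in the question's numpy attempt. Taking the minimum per sample before averaging would be a different loss.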
hawk78 answered Oct 22 '22 04:10