How can I implement a weighted cross entropy loss in tensorflow using sparse_softmax_cross_entropy_with_logits

I am starting to use TensorFlow (coming from Caffe), and I am using the loss sparse_softmax_cross_entropy_with_logits. The function accepts labels like 0, 1, ..., C-1 instead of one-hot encodings. Now I want to weight the loss depending on the class label. I know this could be done, perhaps with a matrix multiplication, if I used softmax_cross_entropy_with_logits (one-hot encoding). Is there any way to do the same with sparse_softmax_cross_entropy_with_logits?

Roger Trullo asked Oct 23 '16 00:10



3 Answers

# Uses the TensorFlow 1.x API (tf.InteractiveSession, tf.losses)
import tensorflow as tf
import numpy as np

np.random.seed(123)
sess = tf.InteractiveSession()

# let's say we have the logits and labels of a batch of size 6 with 5 classes
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)

# specify some class weightings
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])

# specify the weights for each sample in the batch (without having to compute the onehot label matrix)
weights = tf.gather(class_weights, labels)

# compute the loss
tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval()
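
To see why gathering the class weights by label gives the same per-sample weights as the one-hot matrix product mentioned in the question, here is a minimal NumPy sketch. The helper name sparse_softmax_cross_entropy and the final reduction (sum of weighted losses divided by the number of nonzero weights) are assumptions made for illustration, not a copy of TensorFlow's internals.

```python
import numpy as np

def sparse_softmax_cross_entropy(logits, labels):
    """Per-example cross entropy for integer labels (no one-hot needed)."""
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick out the log-probability of the true class in each row.
    return -log_probs[np.arange(len(labels)), labels]

rng = np.random.default_rng(123)
logits = rng.integers(0, 10, size=(6, 5)).astype(np.float64)
labels = rng.integers(0, 5, size=6)
class_weights = np.array([0.3, 0.1, 0.2, 0.3, 0.1])

# Sparse route: look up each example's weight by its integer label.
weights_sparse = class_weights[labels]

# One-hot route: the matrix product mentioned in the question.
one_hot = np.eye(len(class_weights))[labels]
weights_one_hot = one_hot @ class_weights

per_example = sparse_softmax_cross_entropy(logits, labels)
weighted = weights_sparse * per_example

# Assumed reduction: sum of weighted losses over the count of nonzero weights.
loss = weighted.sum() / np.count_nonzero(weights_sparse)
print(np.allclose(weights_sparse, weights_one_hot), loss)
```

Both routes produce the same weight vector, so the one-hot matrix never has to be built when the labels are already integers.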
mauna answered Sep 24 '22 11:09


Specifically for binary classification, there is weighted_cross_entropy_with_logits, which computes a weighted sigmoid cross entropy.
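
For reference, a rough NumPy sketch of what that binary loss computes, as far as I understand the documented formula: a sigmoid cross entropy in which only the positive-label term is scaled by pos_weight. The function name weighted_sigmoid_cross_entropy and the sample values are made up for illustration.

```python
import numpy as np

def weighted_sigmoid_cross_entropy(labels, logits, pos_weight):
    # -log(sigmoid(x)) and -log(1 - sigmoid(x)), written with logaddexp for stability.
    neg_log_p = np.logaddexp(0.0, -logits)
    neg_log_not_p = np.logaddexp(0.0, logits)
    # Only the positive-class term is scaled by pos_weight.
    return pos_weight * labels * neg_log_p + (1.0 - labels) * neg_log_not_p

# Hypothetical binary labels and logits for a batch of 4.
labels = np.array([1.0, 0.0, 1.0, 0.0])
logits = np.array([2.0, -1.0, -0.5, 0.3])

unweighted = weighted_sigmoid_cross_entropy(labels, logits, pos_weight=1.0)
weighted = weighted_sigmoid_cross_entropy(labels, logits, pos_weight=3.0)
print(unweighted, weighted)
```

With pos_weight above 1, mistakes on positive examples cost more, while the loss on negative examples is unchanged.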

sparse_softmax_cross_entropy_with_logits is tailored for a highly efficient non-weighted operation (see SparseSoftmaxXentWithLogitsOp, which uses SparseXentEigenImpl under the hood), so it's not "pluggable".

In the multi-class case, your options are either to switch to one-hot encoding or to use the tf.losses.sparse_softmax_cross_entropy loss function in a hacky way, as already suggested, where you have to pass the weights depending on the labels in the current batch.

Maxim answered Sep 24 '22 11:09


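The class weights are multiplied by the logits, so the same approach still works with sparse_softmax_cross_entropy_with_logits. Note that this rescales the logits before the softmax, which is not the same as weighting each example's loss. Refer to the solution for "Loss function for class imbalanced binary classifier in Tensor flow".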

As a side note, you can pass weights directly into sparse_softmax_cross_entropy:

tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None)

This method is for cross-entropy loss using tf.nn.sparse_softmax_cross_entropy_with_logits.

Weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If weight is a tensor of size [batch_size], then the loss weights apply to each corresponding sample.
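
As a small illustration of those two cases, here is a NumPy sketch of scalar versus per-sample weighting. The loss values and weights are hypothetical numbers chosen for the example, and the plain elementwise product is an assumption about how the coefficient is applied before any reduction.

```python
import numpy as np

# Hypothetical per-example losses for a batch of 4.
per_example_loss = np.array([0.2, 1.5, 0.7, 0.1])

# Scalar weight: every example's loss is scaled by the same factor.
scaled = 2.0 * per_example_loss

# Weight of shape [batch_size]: each example gets its own coefficient.
sample_weights = np.array([1.0, 0.5, 0.0, 2.0])
weighted = sample_weights * per_example_loss

print(scaled.sum(), weighted.sum())
```

A weight of 0 removes an example from the loss entirely, which is why per-sample weights can also be used for masking.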

Alexa Greenberg answered Sep 25 '22 11:09