 

How does TensorFlow/Keras's class_weight parameter of the fit() function work?

I do semantic segmentation with TensorFlow 1.12 and Keras. I supply a vector of weights (size equal to the number of classes) to tf.keras.Model.fit() using its class_weight parameter. I was wondering how this works internally. I use custom loss functions (dice loss and focal loss, among others), and the weights cannot be premultiplied with the predictions or the one-hot ground truth before being fed to the loss function, since that wouldn't make any sense. My loss function outputs a single scalar value, so the weights cannot simply be multiplied with its output either. So where and how exactly are the class weights taken into account?

My custom loss function is:

def cross_entropy_loss(onehots_true, logits): # Inputs are [BATCH_SIZE, height, width, num_classes]
    logits, onehots_true = mask_pixels(onehots_true, logits) # Removes pixels for which no ground truth exists, and returns shape [num_gt_pixels, num_classes]
    return tf.losses.softmax_cross_entropy(onehots_true, logits)
asked Sep 14 '19 by EmielBoss

2 Answers

As mentioned in the official Keras docs:

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

Basically, we provide class weights when there is class imbalance, i.e. when the training samples are not uniformly distributed across the classes: some classes have fewer samples while others have more.

We want the classifier to pay more attention to the classes with fewer samples. One way to achieve this is to scale up the loss contribution of those classes. A larger loss produces larger gradient updates for those classes, which pushes the model to classify them more accurately.
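One common heuristic for choosing the weights, not prescribed by Keras itself but the same one scikit-learn's class_weight='balanced' option uses, is inverse class frequency. A minimal NumPy sketch with made-up labels:

```python
import numpy as np

# Hypothetical imbalanced label vector: 90 samples of class 0, 10 of class 1
labels = np.array([0] * 90 + [1] * 10)

# Inverse-frequency weights: n_samples / (n_classes * count_per_class)
classes, counts = np.unique(labels, return_counts=True)
weights = len(labels) / (len(classes) * counts)
class_weight = dict(zip(classes.tolist(), weights.tolist()))
# The rare class 1 gets a weight of 5.0, the common class 0 about 0.56
```

The rarer a class, the larger its weight, so its samples contribute proportionally more to the total loss.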

In terms of Keras, we pass a dict mapping class indices to their weights (the factors by which the loss value will be multiplied). For example:

class_weights = {0: 1.2, 1: 0.9}

Internally, the loss value for each sample of class 0 or class 1 is multiplied by the corresponding weight:

weighted_loss_class0 = loss0 * class_weights[0]
weighted_loss_class1 = loss1 * class_weights[1]

These weighted losses are then used for backpropagation.
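The per-sample effect can be illustrated with a small NumPy sketch (the loss values and labels below are made up for illustration):

```python
import numpy as np

class_weights = {0: 1.2, 1: 0.9}

# Hypothetical per-sample losses and the true class of each sample
losses = np.array([0.50, 0.20, 0.80, 0.10])
labels = np.array([0, 1, 0, 1])

# Each sample's loss is scaled by the weight of its class
sample_weights = np.array([class_weights[c] for c in labels])
weighted_losses = losses * sample_weights
# weighted_losses == [0.60, 0.18, 0.96, 0.09]
```

Samples of class 0 now contribute 1.2x their raw loss to the gradient, samples of class 1 only 0.9x.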


answered Dec 05 '22 by Shubham Panchal

You can refer to the code below, excerpted from the Keras source on GitHub:

    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])

    if len(class_sample_weight) != len(y_classes):
        # subtract the sets to pick all missing classes
        existing_classes = set(y_classes)
        existing_class_weight = set(class_weight.keys())
        raise ValueError(
            '`class_weight` must contain all classes in the data.'
            ' The classes %s exist in the data but not in '
            '`class_weight`.' % (existing_classes - existing_class_weight))

    if class_sample_weight is not None and sample_weight is not None:
        # Multiply weights if both are provided.
        return class_sample_weight * sample_weight

So, as you can see, the class_weight dict is first converted into a per-sample NumPy array class_sample_weight (one weight per training sample, looked up by that sample's class), which is then multiplied element-wise with sample_weight if one was also provided.
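Stripped of the surrounding Keras machinery, the conversion can be sketched as a standalone function (a simplified sketch; the real Keras code handles several more cases):

```python
import numpy as np

def class_weight_to_sample_weight(y_classes, class_weight):
    """Turn a {class: weight} dict into one weight per sample (sketch)."""
    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])
    # If some class in the data has no entry in the dict, fail loudly,
    # mirroring the ValueError in the Keras snippet above
    if len(class_sample_weight) != len(y_classes):
        missing = set(y_classes) - set(class_weight.keys())
        raise ValueError('`class_weight` must contain all classes in the '
                         'data. Missing: %s' % missing)
    return class_sample_weight

y = [0, 0, 1, 2]
sw = class_weight_to_sample_weight(y, {0: 1.0, 1: 2.0, 2: 3.0})
# sw == [1., 1., 2., 3.]
```

The resulting array has the same length as the label vector, so it can be multiplied directly against the per-sample loss vector during training.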

source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training_utils.py

answered Dec 05 '22 by eugen