 

Keras: weighted binary crossentropy

I tried to implement a weighted binary crossentropy with Keras, but I am not sure if the code is correct. The training output seems a bit confusing: after a few epochs I only get an accuracy of ~0.15. I think that's far too low, even for a random guess.

In general, about 11% of the outputs are ones and 89% are zeros, so the weights are w_zero=0.89 and w_one=0.11.

My code:

def create_weighted_binary_crossentropy(zero_weight, one_weight):

    def weighted_binary_crossentropy(y_true, y_pred):

        # Original binary crossentropy (see losses.py):
        # K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)

        # Calculate the binary crossentropy
        b_ce = K.binary_crossentropy(y_true, y_pred)

        # Apply the weights
        weight_vector = y_true * one_weight + (1. - y_true) * zero_weight
        weighted_b_ce = weight_vector * b_ce

        # Return the mean error
        return K.mean(weighted_b_ce)

    return weighted_binary_crossentropy
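For reference, this is roughly how I use it when compiling (the optimizer and the data variable names here are just placeholders):

# Pass the returned closure as the loss
model.compile(optimizer='adam',
              loss=create_weighted_binary_crossentropy(zero_weight=0.89, one_weight=0.11),
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10, batch_size=32)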

Maybe someone sees what's wrong?

Thank you

Kevin Meier asked Sep 02 '17


People also ask

What is weighted binary cross-entropy?

Keras-Weighted-Binary-Cross-Entropy modifies the binary cross-entropy function found in Keras by adding a weighting. This weight is determined dynamically for every batch by identifying how many positive and negative classes are present and adjusting the weights accordingly.

What is binary Crossentropy?

Binary crossentropy is a loss function that is used in binary classification tasks. These are tasks that answer a question with only two choices (yes or no, A or B, 0 or 1, left or right).

How is binary Crossentropy calculated?

Binary cross-entropy compares each predicted probability to the actual class label, which can be either 0 or 1. It then computes a score that penalizes the probabilities according to their distance from the expected value, i.e., how close or far the prediction is from the actual label.
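For example, a quick sketch of that computation with NumPy (the numbers are made up):

import numpy as np

# Hypothetical predicted probabilities and true labels
y_true = np.array([1., 0., 1., 0.])
y_pred = np.array([0.9, 0.2, 0.4, 0.1])

# Per-sample binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]
bce = -(y_true * np.log(y_pred) + (1. - y_true) * np.log(1. - y_pred))
print(bce.mean())  # the reported loss is the mean over samples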

What is binary cross-entropy keras?

Binary cross-entropy calculates the cross-entropy loss between the predicted classes and the true classes. By default, the sum_over_batch_size reduction is used, which means the loss returns the average of the per-sample losses in the batch.


2 Answers

You can use the sklearn module to automatically calculate the weights for each class like this:

# Imports
import numpy as np
from sklearn.utils import class_weight
from keras.models import Sequential
from keras.layers import Dense

# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
                                            np.unique(y_train),
                                            y_train)

# Add the class weights to the training
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of class_weight.compute_class_weight() is a numpy array, e.g. [2.57569845 0.68250928].
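Depending on the Keras version, the class_weight argument of model.fit() may need to be a dict mapping class indices to weights rather than a plain array; if so, the array can be converted first:

# Map class index -> weight, e.g. {0: 2.57569845, 1: 0.68250928}
weights = dict(enumerate(weights))
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)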

tsveti_iko answered Sep 20 '22


Normally, the minority class will have a higher class weight. It'll be better to use one_weight=0.89, zero_weight=0.11 (by the way, you can use class_weight={0: 0.11, 1: 0.89}, as suggested in the comments).

Under class imbalance, your model sees many more zeros than ones. It will also learn to predict more zeros than ones, because the training loss can be minimized by doing so. That's also why you're seeing an accuracy close to the proportion 0.11. If you take an average over the model's predictions, it should be very close to zero.

The purpose of using class weights is to change the loss function so that the training loss cannot be minimized by the "easy solution" (i.e., predicting zeros), and that's why it'll be better to use a higher weight for ones.

Note that the best weights are not necessarily 0.89 and 0.11. Sometimes you might have to try something like taking logarithms or square roots (or any weights satisfying one_weight > zero_weight) to make it work.
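As a sketch, reusing the loss factory from the question with the weights swapped:

# Higher weight on the minority class (ones)
weighted_loss = create_weighted_binary_crossentropy(zero_weight=0.11, one_weight=0.89)
model.compile(optimizer='rmsprop', loss=weighted_loss, metrics=['accuracy'])

# Roughly equivalent built-in alternative, as suggested in the comments:
# model.fit(x_train, y_train, class_weight={0: 0.11, 1: 0.89}, ...)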

Yu-Yang answered Sep 16 '22