Keras: what does class_weight actually try to balance?

My data has extreme class imbalance. About 99.99% of samples are negatives; the positives are (roughly) equally divided among three other classes. I think the models I'm training are just predicting the majority class basically all the time. For this reason, I'm trying to weight the classes.

Model

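from keras.models import Sequential
from keras.layers import Conv1D, BatchNormalization, Dropout, Flatten, Dense

model = Sequential()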

#Layer 1
model.add(Conv1D( {{choice([32, 64, 90, 128])}}, {{choice([3, 4, 5, 6, 8])}}, activation='relu', kernel_initializer=kernel_initializer, input_shape=X_train.shape[1:]))
model.add(BatchNormalization())

#Layer 2
model.add(Conv1D({{choice([32, 64, 90, 128])}}, {{choice([3, 4, 5, 6])}}, activation='relu',kernel_initializer=kernel_initializer))
model.add(Dropout({{uniform(0, 0.9)}}))

#Flatten
model.add(Flatten())

#Output
model.add(Dense(4, activation='softmax'))

(The {{...}} placeholders are Hyperas templates for hyperparameter search.)

How I've tried to weight it

1. Using class_weight in model.fit()

model.fit(X_train, Y_train, batch_size=64, epochs=10, verbose=2, validation_data=(X_test, Y_test), class_weight={0: 9999, 1:9999, 2: 9999, 3:1})
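A rough mental model for approach 1, as a sketch rather than Keras's exact internals: class_weight looks up the weight of each sample's true class and multiplies that sample's loss by it before the per-sample losses are reduced over the batch. The hypothetical helper weighted_ce_per_sample below reproduces that idea in NumPy for one-hot targets, using the weights from the question.

```python
import numpy as np


def weighted_ce_per_sample(y_true_onehot, y_prob, class_weight, eps=1e-7):
    """Cross-entropy of each sample, scaled by the weight of its true class.

    Hypothetical helper illustrating class_weight; Keras's own code also
    handles masking and its own batch reduction.
    """
    y_prob = np.clip(y_prob, eps, 1 - eps)
    ce = -np.sum(y_true_onehot * np.log(y_prob), axis=1)  # plain cross-entropy per sample
    true_class = np.argmax(y_true_onehot, axis=1)          # integer label of each sample
    w = np.array([class_weight[c] for c in true_class])    # weight of that label
    return w * ce


class_weight = {0: 9999, 1: 9999, 2: 9999, 3: 1}
y_true = np.eye(4)[[3, 0]]                  # one negative (class 3), one positive (class 0)
y_prob = np.array([[0.1, 0.1, 0.1, 0.7],
                   [0.1, 0.1, 0.1, 0.7]])   # the model favours class 3 for both
per_sample = weighted_ce_per_sample(y_true, y_prob, class_weight)
print(per_sample)         # the misclassified positive dominates the loss
print(per_sample.mean())  # one possible batch reduction: a plain mean
```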

2. Using class_weight in model.fit() with sklearn's compute_class_weight()

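y_int = np.argmax(Y_train, axis=1)  # assuming Y_train is one-hot: compute_class_weight needs integer labels
model.fit(..., class_weight=dict(enumerate(class_weight.compute_class_weight(class_weight="balanced", classes=np.unique(y_int), y=y_int))))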
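For reference, scikit-learn documents the "balanced" heuristic as n_samples / (n_classes * np.bincount(y)). The hypothetical helper below reproduces that formula with NumPy on made-up labels and shows what it balances: after weighting, every class carries the same summed weight, so no class dominates the loss merely by being frequent.

```python
import numpy as np


def balanced_weights(y_int):
    """n_samples / (n_classes * count_per_class), the formula scikit-learn
    documents for compute_class_weight("balanced", ...).

    Assumes labels are 0..K-1 and every class occurs at least once.
    """
    counts = np.bincount(y_int)
    return len(y_int) / (len(counts) * counts)


# Made-up labels with the question's shape: class 3 is the huge negative class.
y_int = np.array([3] * 9997 + [0, 1, 2])
w = balanced_weights(y_int)
class_weight = dict(enumerate(w))  # {class index: weight}, the form fit() accepts
print(class_weight)

# What gets balanced: weight x count is the same for every class,
# so each class carries an equal share of the total sample weight.
print(w * np.bincount(y_int))  # every entry equals n_samples / n_classes
```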

3. With a custom loss function

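import numpy as np
from keras import backend as K

def custom_loss(weights):
    # Weighted categorical cross-entropy, adapted from
    # gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d
    weights = K.variable(weights)  # float tensor, so it multiplies cleanly with the float predictions

    def loss(Y_true, Y_pred):
        # Rescale so each row sums to 1 (already true after softmax), then clip to avoid log(0)
        Y_pred = Y_pred / K.sum(Y_pred, axis=-1, keepdims=True)
        Y_pred = K.clip(Y_pred, K.epsilon(), 1 - K.epsilon())

        # Log-likelihood of the true class, scaled by that class's weight
        weighted = Y_true * K.log(Y_pred) * weights
        return -K.sum(weighted, axis=-1)

    return loss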

extreme_weights = np.array([9999, 9999, 9999, 1])
model.compile(loss=custom_loss(extreme_weights),
            metrics=['accuracy'],
            optimizer={{choice(['rmsprop', 'adam', 'sgd','Adagrad','Adadelta'])}}
            )

#(then fit *without* class_weight)
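To check approach 3 by hand, the NumPy mirror below (a hypothetical helper, not part of Keras) applies the same steps as custom_loss. With one-hot targets only the true class c survives the sum, so each sample's loss is -w[c] * log(p[c]). That is the same per-sample scaling class_weight applies, which suggests approaches 1 and 3 weight samples the same way, up to how Keras reduces the batch.

```python
import numpy as np


def custom_loss_np(y_true, y_pred, weights, eps=1e-7):
    """NumPy mirror of the Keras custom_loss, for checking values by hand."""
    y_pred = y_pred / np.sum(y_pred, axis=-1, keepdims=True)  # rows sum to 1
    y_pred = np.clip(y_pred, eps, 1 - eps)                     # avoid log(0)
    return -np.sum(y_true * np.log(y_pred) * weights, axis=-1)


weights = np.array([9999, 9999, 9999, 1], dtype=float)
y_true = np.eye(4)[[0, 3]]                   # a class-0 sample, then a class-3 sample
y_pred = np.array([[0.2, 0.1, 0.1, 0.6],
                   [0.2, 0.1, 0.1, 0.6]])
loss = custom_loss_np(y_true, y_pred, weights)
print(loss)  # equals -w[c] * log(p[c]) for each sample's true class c
```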

Results

Poor. Accuracy across all classes is ~0.99, and unbalanced accuracy for all classes is ~0.5. But more meaningful metrics, such as the area under the precision-recall curve (auPRC), tell a different story: the auPRC is nearly 1 for the majority class and nearly 0 for the rest.
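Why a high overall accuracy says little here: with this degree of imbalance, a model that never predicts a positive can score even higher than 0.99. The toy sketch below uses made-up labels (class 3 as the negatives, roughly 99.99% of samples) and compares overall accuracy with per-class recall for an always-majority predictor.

```python
import numpy as np

# Made-up labels with the question's imbalance: class 3 is ~99.99% of samples.
y_true = np.array([3] * 99_990 + [0] * 3 + [1] * 3 + [2] * 4)
y_pred = np.full_like(y_true, 3)  # a model that always predicts the majority class

accuracy = np.mean(y_pred == y_true)
recall_per_class = [np.mean(y_pred[y_true == c] == c) for c in range(4)]
print(accuracy)          # very high, despite never finding a positive
print(recall_per_class)  # zero for classes 0, 1 and 2
```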

Is this how Keras balances classes, by just making sure that accuracy is the same across them? Or should other metrics, like auPRC, be equal or comparable across classes too? Or am I specifying the weights wrong?

asked Jul 19 '18 at 23:07 by Randoms



1 Answer

Keras uses the class weights during training, but the accuracy metric does not reflect them. Accuracy is computed over all samples regardless of the weights between classes, because you are using the plain 'accuracy' metric in compile(). You can define a custom, more informative metric such as a class-weighted (balanced) accuracy and use that, or use the sklearn metrics (e.g. f1_score(), whose average parameter can be 'binary', 'macro', 'weighted', etc.).
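One way to get such class-aware scores, sketched with NumPy on integer labels after training (for example from np.argmax(model.predict(X_test), axis=1)): balanced accuracy is the mean of per-class recall, and macro F1 is the unweighted mean of per-class F1, intended to match f1_score(..., average='macro') with undefined scores set to 0. The helper names are hypothetical.

```python
import numpy as np


def per_class_scores(y_true, y_pred, n_classes):
    """Recall, precision and F1 for each class, from integer labels."""
    recall, precision, f1 = [], [], []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fn = np.sum((y_pred != c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        r = tp / (tp + fn) if tp + fn else 0.0
        p = tp / (tp + fp) if tp + fp else 0.0
        recall.append(r)
        precision.append(p)
        f1.append(2 * p * r / (p + r) if p + r else 0.0)
    return np.array(recall), np.array(precision), np.array(f1)


def balanced_accuracy_np(y_true, y_pred, n_classes):
    """Mean per-class recall: every class counts equally, whatever its size."""
    return per_class_scores(y_true, y_pred, n_classes)[0].mean()


def macro_f1_np(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1."""
    return per_class_scores(y_true, y_pred, n_classes)[2].mean()


y_true = np.array([0, 1, 2, 3, 3, 3, 3, 3])
y_pred = np.array([0, 3, 3, 3, 3, 3, 3, 3])  # finds class 0, misses classes 1 and 2
print(np.mean(y_true == y_pred))             # plain accuracy
print(balanced_accuracy_np(y_true, y_pred, 4))
print(macro_f1_np(y_true, y_pred, 4))
```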

Example:

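import numpy as np
from sklearn.metrics import f1_score

def macro_f1(y_true, y_pred):
    # sklearn metrics work on NumPy label arrays, not on Keras tensors,
    # so convert one-hot targets and predicted probabilities to class indices
    return f1_score(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1), average='macro')


model.compile(loss=custom_loss(extreme_weights),
        metrics=['accuracy'],
        optimizer={{choice(['rmsprop', 'adam', 'sgd','Adagrad','Adadelta'])}}
        )
# after fitting, score the held-out set with sklearn
print(macro_f1(Y_test, model.predict(X_test)))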
answered Oct 03 '22 at 08:10 by dter