Keras multi-class prediction output is limited to one class

I built a sequential Keras model with 35,000 input samples and 20 predictors. The output class distribution in the test data is:

  • Class_0 = 5.6%
  • Class_1 = 7.7%
  • Class_2 = 35.6%
  • Class_3 = 45.7%
  • Class_4 = 5.4%

After transforming the outputs into a binary class matrix with np_utils.to_categorical, the training accuracy is around 54%. When I predict on the test data (15,000 samples), every prediction (100%) falls into class_3, the most frequent class in the training output. What causes this bias toward a single class, and why is there not a single prediction for the other classes? How can I make the model sensitive to the rarer classes and improve accuracy, especially when a class's occurrence in the training data is as low as 1-3%?

# Imports assumed by the snippet (Keras era with np_utils.to_categorical)
from keras.models import Sequential
from keras.layers import Dense, Dropout

model = Sequential()
model.add(Dense(40, input_dim=20, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='relu'))
model.add(Dense(5, activation='softmax'))  # one output unit per class
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

model.fit(X, Y, epochs=500, verbose=1)
asked Apr 02 '17 by Ray

2 Answers

The easiest way to rebalance your loss is to:

  1. Define a class_weights dict:

    class_weights_dict = {0: 1.0 / class_0_freq, 1: 1.0 / class_1_freq, ... }
    

    where class_i_freq is the frequency of the i-th class.

  2. Change your fit function to:

    model.fit(X, Y, epochs=500,  verbose=1, class_weight=class_weights_dict)
    

The model defined above should be equivalent to your model with Bayesian reweighting of the classes.
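
A minimal sketch of computing these weights directly from the one-hot training matrix Y (assuming Y came from np_utils.to_categorical, as in the question):

import numpy as np

# Derive each class's frequency from the one-hot matrix Y, then invert it.
class_counts = Y.sum(axis=0)                    # samples per class
class_freqs = class_counts / float(Y.shape[0])  # fraction per class
class_weights_dict = {i: 1.0 / f for i, f in enumerate(class_freqs)}

model.fit(X, Y, epochs=500, verbose=1, class_weight=class_weights_dict)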

answered Nov 05 '22 by Marcin Możejko


One way to solve this is by oversampling examples of the under-represented classes: if you have two classes A (66.6%) and B (33.3%), sample B twice as often as A. Even more simply, you can change your dataset by duplicating every B example once, turning it into A + 2*B.
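
As a rough sketch (hypothetical NumPy arrays X and Y, with Y one-hot as in the question), resampling every class up to the size of the largest one could look like this:

import numpy as np

# Resample indices so every class appears as often as the largest class.
labels = Y.argmax(axis=1)             # integer label per sample
max_count = np.bincount(labels).max()
idx = np.concatenate([
    np.random.choice(np.where(labels == c)[0], size=max_count, replace=True)
    for c in np.unique(labels)
])
np.random.shuffle(idx)                # avoid long runs of a single class
X_balanced, Y_balanced = X[idx], Y[idx]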

Or, you can modify your loss function so that it penalizes misclassifications of the under-represented classes more heavily.
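
One common pattern for this (a sketch, not spelled out in the answer itself) is categorical cross-entropy wrapped with per-class weights via the Keras backend; the weight values below are hypothetical:

from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # Scale each class's log-loss term by weights[class].
    weights = K.constant(weights)
    def loss(y_true, y_pred):
        # Clip to avoid log(0), then weight each class's contribution.
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)
    return loss

# Hypothetical usage: heavier weights on the rare classes 0, 1 and 4.
model.compile(loss=weighted_categorical_crossentropy([5.0, 4.0, 1.0, 1.0, 5.0]),
              optimizer='rmsprop', metrics=['accuracy'])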

answered Nov 05 '22 by vikasreddy