 

How to balance dataset using fit_generator() in Keras?

I am trying to use Keras to fit a CNN model that classifies two classes of data. My dataset is imbalanced and I want to balance it. I don't know whether I can use class_weight in model.fit_generator. Can I pass class_weight="balanced" to model.fit_generator?

The main code:

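import numpy as np

def generate_arrays_for_training(indexPat, paths, start=0, end=100):
    # Endlessly yields one (x, y) batch per file in the [start%, end%) slice of `paths`
    while True:
        from_ = int(len(paths) / 100 * start)
        to_ = int(len(paths) / 100 * end)
        for i in range(from_, to_):
            f = paths[i]
            x = np.load(PathSpectogramFolder + f)
            x = np.expand_dims(x, axis=0)  # add a batch axis of size 1

            # One-hot labels: files with 'P' in the name are class 1, all others class 0
            if 'P' in f:
                y = np.repeat([[0, 1]], x.shape[0], axis=0)
            else:
                y = np.repeat([[1, 0]], x.shape[0], axis=0)
            yield (x, y)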
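
# Train on the first 75% of the files and validate on the last 25%.
# One file is yielded per step, so the step counts roughly match the number of files in each split.
# Note: shuffle only applies to keras.utils.Sequence inputs; it has no effect with a plain generator.
history = model.fit_generator(generate_arrays_for_training(indexPat, filesPath, end=75),
                              validation_data=generate_arrays_for_training(indexPat, filesPath, start=75),
                              steps_per_epoch=int((len(filesPath) - int(len(filesPath) / 100 * 25))),
                              validation_steps=int((len(filesPath) - int(len(filesPath) / 100 * 75))),
                              verbose=2,
                              epochs=15, max_queue_size=2, shuffle=True, callbacks=[callback])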

Edayildiz asked Nov 07 '22 04:11


1 Answer

If you don't want to change your data creation process, you can pass class_weight to fit_generator. Set it with a dictionary and tune the values by observing the results. For instance, suppose you have 50 examples of class 0 and 100 examples of class 1. Without class_weight, the loss function weights every example equally, so class 1, with twice as many examples, dominates the loss and the model tends to neglect class 0. But when you set:

class_weight = {0: 2, 1: 1}

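the loss function gives class 0 twice the weight of class 1: each class-0 sample's loss is multiplied by 2, so misclassifying the underrepresented class is penalized twice as heavily as before. In this example, the 50 class-0 samples then carry the same total weight (50 × 2 = 100) as the 100 class-1 samples, which helps the model cope with the imbalanced data.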
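A minimal pure-Python sketch of how a class weight enters the loss, assuming categorical cross-entropy where each sample's loss is multiplied by the weight of its true class. The helper `weighted_cross_entropy` is hypothetical and only illustrates the idea; Keras's own averaging of the weighted loss over a batch may differ.

```python
import math

def weighted_cross_entropy(y_true, y_prob, class_weight):
    """Mean cross-entropy where each sample's loss is scaled by the weight
    of its true class (hypothetical helper, for illustration only)."""
    total = 0.0
    for label, probs in zip(y_true, y_prob):
        # -log(probability assigned to the correct class), times that class's weight
        total += class_weight[label] * -math.log(probs[label])
    return total / len(y_true)

# A class-0 sample that the model misclassifies (80% confidence in class 1)
y_true = [0]
y_prob = [[0.2, 0.8]]

unweighted = weighted_cross_entropy(y_true, y_prob, {0: 1, 1: 1})
weighted = weighted_cross_entropy(y_true, y_prob, {0: 2, 1: 1})
print(unweighted, weighted)  # the weighted loss is exactly twice the unweighted one
```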

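Note that Keras documents class_weight as a dictionary mapping class indices to weights; class_weight='balanced' is a scikit-learn option, not a Keras one. You can compute the "balanced" weights yourself (scikit-learn's sklearn.utils.class_weight.compute_class_weight does this) and pass the resulting dictionary. My suggestion, though, is to create a dictionary like class_weight = {0: a1, 1: a2} and try different values for a1 and a2, so you can see how they change the results.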
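If you want the "balanced" weights without scikit-learn, they follow the formula scikit-learn documents for its 'balanced' mode: weight[c] = n_samples / (n_classes * count[c]). The sketch below computes them from training file names, assuming the question's labelling rule ('P' in the file name means class 1). The helpers `label_from_filename` and `balanced_class_weight` and the file names are hypothetical. Keras maps one-hot targets such as [1, 0] and [0, 1] to class indices with argmax, so the dictionary keys 0 and 1 line up with the question's labels.

```python
from collections import Counter

def label_from_filename(name):
    # Same rule as the question's generator: 'P' in the file name -> class 1, else class 0
    return 1 if 'P' in name else 0

def balanced_class_weight(labels):
    """Weights computed like scikit-learn's 'balanced' mode:
    n_samples / (n_classes * number of samples in each class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Hypothetical training files: 50 class-0 files and 100 class-1 ('P') files.
# With the question's split this would be filesPath[:int(len(filesPath) / 100 * 75)].
train_files = [f"N_{i}.npy" for i in range(50)] + [f"P_{i}.npy" for i in range(100)]
weights = balanced_class_weight([label_from_filename(f) for f in train_files])
print(weights)  # {0: 1.5, 1: 0.75}

# Then, with the question's model and generator:
# history = model.fit_generator(..., class_weight=weights)
```

The result has the same 2:1 ratio as {0: 2, 1: 1}; only the overall scale differs.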

Alternatively, instead of using class_weight, you can resample the data, for example by undersampling the majority class. Bootstrapping methods are worth checking for that purpose.
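One simple form of undersampling for this setup is to randomly drop majority-class files from the list passed to the generator, so that both classes keep the same number of files. This is a minimal sketch assuming the question's 'P'-in-filename labelling; `undersample_paths` and `label_from_filename` are hypothetical helpers, not Keras functions. Shuffling the kept list matters because the question's generator splits training and validation data by position (first 75% / last 25%); if you want validation to reflect the real class distribution, undersample only the training portion.

```python
import random

def label_from_filename(name):
    # Question's rule: 'P' in the file name -> class 1, else class 0
    return 1 if 'P' in name else 0

def undersample_paths(paths, label_fn, seed=0):
    """Randomly keep as many files per class as the smallest class has
    (hypothetical helper). Returns a shuffled list of file names."""
    rng = random.Random(seed)
    by_class = {}
    for p in paths:
        by_class.setdefault(label_fn(p), []).append(p)
    n_min = min(len(files) for files in by_class.values())
    kept = []
    for files in by_class.values():
        kept.extend(rng.sample(files, n_min))  # sample without replacement
    rng.shuffle(kept)  # mix classes so a positional train/validation split sees both
    return kept

# Hypothetical file names: 50 class-0 files and 100 class-1 ('P') files
files = [f"N_{i}.npy" for i in range(50)] + [f"P_{i}.npy" for i in range(100)]
balanced = undersample_paths(files, label_from_filename)
print(len(balanced), sum(label_from_filename(f) for f in balanced))  # 100 50
```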

Angerato answered Nov 14 '22 21:11