
How to weight classes using fit_generator() in Keras?

Tags:

python

keras

I am trying to use Keras to fit a CNN model to classify images. The data set has many more images from certain classes, so it is unbalanced.

I have read various explanations of how to weight the loss to account for this in Keras, e.g. https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras, which is nicely explained. But they always cover the fit() function, not fit_generator().

Indeed, in the fit_generator() function we don't have the 'class_weights' parameter; instead we have 'weighted_metrics', whose description I don't understand: "weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing."

How can I go from 'class_weights' to 'weighted_metrics'? Would anyone have a simple example?

miki asked Aug 31 '18 15:08




1 Answer

fit_generator does have a class_weight parameter (Keras v2.2.2). According to the docs:

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

Assuming you have two classes with indices 0 and 1 (e.g. positive and negative), you can pass class_weight to fit_generator as a dictionary keyed by class index, as the docs describe:

model.fit_generator(gen, class_weight={0: 0.7, 1: 1.3})
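If you would rather derive the weights from the data than pick them by hand, one common heuristic is the "balanced" formula n_samples / (n_classes * count_per_class). Below is a minimal sketch of that idea, not part of the original answer. The helper name balanced_class_weights and the example labels are hypothetical, and it assumes integer labels where every class index from 0 to the maximum occurs at least once.

```python
import numpy as np

def balanced_class_weights(labels):
    """Return {class_index: weight} using n_samples / (n_classes * count).

    Assumes integer labels in which every index 0..max appears at least once.
    """
    labels = np.asarray(labels)
    counts = np.bincount(labels)          # number of samples per class index
    n_samples, n_classes = len(labels), len(counts)
    return {i: float(n_samples / (n_classes * c)) for i, c in enumerate(counts)}

# Hypothetical labels: 8 images of class 0 and 2 images of class 1.
train_labels = [0] * 8 + [1] * 2
class_weight = balanced_class_weights(train_labels)

# The dictionary can then be passed on, e.g.:
# model.fit_generator(gen, class_weight=class_weight)
```

The rarer class gets the larger weight, so its samples contribute more to the loss. If the generator comes from flow_from_directory, its classes attribute should hold the integer labels to feed into the helper, though that is worth checking against your Keras version.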
Amir answered Sep 28 '22 06:09