
Imbalanced Dataset for Multi Label Classification

So I trained a deep neural network on a multi-label dataset I created (about 20,000 samples). I switched softmax for sigmoid and tried to minimize (using the Adam optimizer):

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred))
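For reference, this loss treats each of the 10 outputs as an independent yes/no problem, which is what makes sigmoid (rather than softmax) appropriate for multi-label data. Below is a minimal pure-Python sketch of what `tf.nn.sigmoid_cross_entropy_with_logits` computes per entry, using the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))` given in the TensorFlow documentation, followed by the mean over all (sample, class) entries that `tf.reduce_mean` takes. The function names are illustrative, not TensorFlow's.

```python
import math


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy for one sample (lists of equal length).

    Stable form of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)).
    """
    return [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
            for z, x in zip(labels, logits)]


def mean_loss(batch_labels, batch_logits):
    """Average over every (sample, class) entry, like tf.reduce_mean on the full matrix."""
    losses = [loss
              for z_row, x_row in zip(batch_labels, batch_logits)
              for loss in sigmoid_cross_entropy_with_logits(z_row, x_row)]
    return sum(losses) / len(losses)


# Two samples with three labels each; every label is scored independently.
y_true = [[1, 0, 1], [0, 0, 1]]
logits = [[2.0, -1.0, 0.5], [-0.3, -2.0, 1.5]]
print(mean_loss(y_true, logits))
```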

And I end up with this kind of prediction (pretty "constant"):

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646 0.38991812  0.54468459  0.34593087  0.82790571]

Prediction for Im2 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177 0.35400775  0.36479294  0.30002621  0.84438241]

Prediction for Im3 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551 0.43719986  0.54638696  0.20344526  0.88144571]

At first, I thought I just needed to find a threshold value for each class.
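The per-class threshold idea can be sketched as follows: for each class, scan candidate thresholds on a held-out validation set and keep the one with the best F1 score. The F1 criterion, the candidate grid, and the helper names are assumptions for illustration, and thresholds can only help if the scores actually vary with the input.

```python
def f1_score(y_true, y_pred):
    """F1 score for a single class, given parallel lists of 0/1 truths and predictions."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if p and not t)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


def per_class_thresholds(probs, labels, candidates=None):
    """For each class independently, keep the candidate threshold with the best validation F1."""
    if candidates is None:
        candidates = [i / 100 for i in range(5, 100, 5)]  # 0.05, 0.10, ..., 0.95
    thresholds = []
    for k in range(len(labels[0])):
        scores = [row[k] for row in probs]
        truth = [row[k] for row in labels]
        # max() returns the first (lowest) threshold among ties.
        best = max(candidates, key=lambda t: f1_score(truth, [s >= t for s in scores]))
        thresholds.append(best)
    return thresholds


# Toy validation outputs (sigmoid probabilities) and multi-hot ground truth.
val_probs = [[0.9, 0.32], [0.6, 0.25], [0.4, 0.10], [0.7, 0.35]]
val_labels = [[1, 1], [1, 0], [0, 0], [1, 1]]
print(per_class_thresholds(val_probs, val_labels))  # [0.45, 0.3]
```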

But I noticed that, for instance, among my 20000 samples the 1st class appears in about 10800 of them, a 0.54 ratio, and that is the value my prediction hovers around every time. So I think I need to find a way to tackle this "imbalanced dataset" issue.
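This observation can be checked directly. With sigmoid cross-entropy, if a network ignores its input and always outputs the same value q for a class whose positive rate is p, the mean loss for that class is -(p log q + (1 - p) log(1 - q)), which is minimized exactly at q = p. So predictions that sit at each class's frequency are what an input-blind network converges to. A small sketch with toy labels (not the question's data); the helper names are hypothetical:

```python
import math


def label_frequencies(labels):
    """Fraction of samples in which each class is present (rows are multi-hot lists)."""
    n = len(labels)
    return [sum(column) / n for column in zip(*labels)]


def constant_prediction_loss(p, q):
    """Mean sigmoid cross-entropy for a class with positive rate p when q is always predicted."""
    return -(p * math.log(q) + (1 - p) * math.log(1 - q))


# Toy multi-hot labels: class 0 is present in 3 of 5 samples, class 1 in 1 of 5.
labels = [[1, 0], [1, 1], [0, 0], [1, 0], [0, 0]]
freqs = label_frequencies(labels)
print(freqs)  # [0.6, 0.2]

# Scan constant outputs q; the loss is lowest when q equals the class frequency.
p = freqs[0]
grid = [i / 100 for i in range(1, 100)]
best_q = min(grid, key=lambda q: constant_prediction_loss(p, q))
print(best_q)  # 0.6
```

Comparing `label_frequencies` on the real training labels with the average prediction per class is a quick way to tell whether the network is just reproducing the label distribution.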

I thought about reducing my dataset (undersampling) to have about the same number of occurrences for each class, but only 26 samples correspond to one of my classes... That would make me lose a lot of samples...

I read about oversampling, or about penalizing the rare classes even more, but did not really understand how it works.
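Oversampling is the opposite of undersampling: instead of discarding samples from common classes, samples from rare classes are shown more often. In multi-label data a sample cannot be duplicated for one class alone, so this minimal sketch decides how often to repeat a sample from its rarest positive label. The repetition rule and the `max_repeat` cap are assumptions for illustration, not a standard recipe.

```python
import math
from collections import Counter


def oversample_indices(labels, max_repeat=20):
    """Training-set indices in which samples carrying rare labels are repeated.

    Each sample is repeated ceil(max_count / count_of_its_rarest_positive_class)
    times, capped at max_repeat (a hypothetical knob to limit overfitting).
    Samples with no positive label are kept once.
    """
    counts = Counter(k for row in labels for k, v in enumerate(row) if v)
    max_count = max(counts.values())
    indices = []
    for i, row in enumerate(labels):
        positive_counts = [counts[k] for k, v in enumerate(row) if v]
        if positive_counts:
            repeat = min(max_repeat, math.ceil(max_count / min(positive_counts)))
        else:
            repeat = 1
        indices.extend([i] * repeat)
    return indices


# Toy data: class 0 appears 4 times, class 1 only once (in sample 3).
labels = [[1, 0], [1, 0], [1, 0], [1, 1]]
print(oversample_indices(labels))  # [0, 1, 2, 3, 3, 3, 3]
```

Note the multi-label side effect: repeating sample 3 also adds three extra positives for class 0. Oversampling should be applied only to the training split, and repeating the same few rare samples many times makes overfitting likely.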

Can someone share some explanations about these methods, please?

In practice, in TensorFlow, are there functions that help with that?
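Penalizing errors on rare classes can be done by weighting the positive term of the loss per class. TensorFlow provides `tf.nn.weighted_cross_entropy_with_logits`, whose `pos_weight` argument multiplies the loss on positive labels; a per-class vector of weights should broadcast across the class dimension (in older 1.x releases the labels argument was named `targets`, so check the signature for your version). The pure-Python sketch below, with hypothetical helper names, shows one common choice of weight (negatives / positives per class) and the stable form of the weighted loss documented for that function.

```python
import math


def pos_weights(labels):
    """Per-class weight = negatives / positives, computed on multi-hot training labels.

    With this weight, the positives of a class carry as much total weight in the
    loss as its negatives. Classes with no positives get weight 1.0.
    """
    n = len(labels)
    weights = []
    for column in zip(*labels):
        positives = sum(column)
        weights.append((n - positives) / positives if positives else 1.0)
    return weights


def weighted_sigmoid_cross_entropy(z, x, pos_weight):
    """pos_weight * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)), in stable form.

    Rewritten as (1 - z) * x + (1 + (pos_weight - 1) * z) * softplus(-x),
    with softplus(-x) = log1p(exp(-|x|)) + max(-x, 0) to avoid overflow.
    """
    softplus_neg_x = math.log1p(math.exp(-abs(x))) + max(-x, 0.0)
    return (1 - z) * x + (1 + (pos_weight - 1) * z) * softplus_neg_x


toy_labels = [[1, 0], [1, 0], [0, 1], [1, 0]]
weights = pos_weights(toy_labels)  # roughly [0.333, 3.0]
print(weights)
print(weighted_sigmoid_cross_entropy(1, -0.5, weights[1]))
```

With the counts from the question, a class present in 26 of 20000 samples would get a weight of (20000 - 26) / 26, roughly 768, while the class present in 10800 samples would get 9200 / 10800, roughly 0.85. Weights that large can make training unstable, so capping or softening them (for example, taking a square root) is a common adjustment worth trying.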

Any other suggestions?

Thank you :)

PS: The post "Neural Network for Imbalanced Multi-Class Multi-Label Classification" raises the same problem but has no answer!

asked May 31 '17 at 13:05 by A. Piro


2 Answers

Well, having 10000 samples in one class and just 26 in a rare class will indeed be a problem.

However, what you are experiencing looks to me more like "the outputs don't even see the inputs", so the net just learns your output distribution.

To debug this, I would create a reduced set (just for debugging) with, say, 26 samples per class and then try to heavily overfit it. If you get correct predictions, my thought is wrong. But if the net cannot even fit those few undersampled samples, then it is indeed an architecture/implementation problem and not due to the skewed distribution. (You will still need to fix the distribution afterwards, but it won't be as bad as your current results.)
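The reduced debugging set described above could be built like this, assuming multi-hot label rows; `k=26` matches the size of the rarest class in the question, and the helper name is hypothetical. Training on only these indices without regularization should drive the training loss toward zero; if the predictions stay near the label frequencies even then, that points to a bug in the model or the input pipeline rather than to the imbalance.

```python
import random


def per_class_subset(labels, k=26, seed=0):
    """Sorted indices of a small debug set: up to k random samples containing each class.

    A sample carrying several labels may be picked for more than one class,
    so some classes can end up with more than k positives.
    """
    rng = random.Random(seed)
    chosen = set()
    for c in range(len(labels[0])):
        members = [i for i, row in enumerate(labels) if row[c]]
        rng.shuffle(members)
        chosen.update(members[:k])
    return sorted(chosen)


labels = [[1, 0, 0], [1, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 1]]
print(per_class_subset(labels, k=1))
```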

answered Sep 18 '22 at 17:09 by Phillip Bock


Your problem is not the class imbalance but simply the lack of data. 26 samples is a very small dataset for practically any real machine learning task. Class imbalance could easily be handled by ensuring that each minibatch contains at least one sample from every class (this means some samples will be used much more frequently than others, but who cares).
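One way to guarantee at least one sample from every class per minibatch is sketched below, assuming multi-hot label rows. The helper name and the fill strategy (one random positive per class, then uniform random samples for the rest) are assumptions for illustration; duplicates within a batch are allowed.

```python
import random


def balanced_batches(labels, batch_size, num_batches, seed=0):
    """Yield lists of row indices; each batch holds at least one positive of every class."""
    num_classes = len(labels[0])
    if batch_size < num_classes:
        raise ValueError("batch_size must be at least the number of classes")
    by_class = [[i for i, row in enumerate(labels) if row[c]] for c in range(num_classes)]
    if any(not members for members in by_class):
        raise ValueError("every class needs at least one positive sample")
    rng = random.Random(seed)
    all_indices = list(range(len(labels)))
    for _ in range(num_batches):
        # One guaranteed positive per class, so rare-class samples recur constantly.
        batch = [rng.choice(members) for members in by_class]
        # Fill the remaining slots uniformly from the whole training set.
        batch += rng.choices(all_indices, k=batch_size - num_classes)
        rng.shuffle(batch)
        yield batch


# Toy imbalance: 50 samples of class 0, 5 of class 1, 1 of class 2.
labels = [[1, 0, 0]] * 50 + [[0, 1, 0]] * 5 + [[0, 0, 1]]
for batch in balanced_batches(labels, batch_size=8, num_batches=3):
    print(batch)
```

With NumPy arrays, each index list could be used to slice hypothetical `x_train[batch]` and `y_train[batch]` before each optimizer step. Because the rarest class is drawn into every batch, its few samples repeat very often, which is the overfitting risk the answer goes on to mention.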

However, with only 26 samples, this approach (and any other) will quickly lead to overfitting. This could be partly addressed with some form of data augmentation, but there are still too few samples to build something reasonable.

So my suggestion is to collect more data.

answered Sep 20 '22 at 17:09 by Dmitriy Danevskiy