 

Why doesn't class_weight.compute_class_weight() work?

Hello, I am using the class_weight.compute_class_weight() function from the sklearn.utils module.

I have a generator, train_gen, created with ImageDataGenerator().flow_from_directory().

Here is the code:

from sklearn.utils import class_weight
import numpy as np

class_weights = class_weight.compute_class_weight(
           'balanced',
            np.unique(train_gen.classes), 
            train_gen.classes)

# train_class_weights = dict(enumerate(class_weights))
# model.fit_generator(..., class_weight=train_class_weights)

And I get this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-50-d468c4be76b8> in <module>()
      5            'balanced',
      6             np.unique(train_gen.classes),
----> 7             train_gen.classes)
      8 
      9 # train_class_weights = dict(enumerate(class_weights))

TypeError: compute_class_weight() takes 1 positional argument but 3 were given

Does anybody know what the problem could be? Thank you.

James__pxlwk asked Apr 13 '26 20:04


1 Answer

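You should pass the arguments by keyword. In recent versions of scikit-learn the signature is compute_class_weight(class_weight, *, classes, y), so classes and y are keyword-only and only class_weight may be given positionally. That is exactly what the message "takes 1 positional argument but 3 were given" is saying: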

class_weights = class_weight.compute_class_weight(class_weight='balanced',
                classes=np.unique(train_gen.classes), 
                y=train_gen.classes)
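To show both why the positional call fails and what 'balanced' computes, here is a minimal pure-Python sketch. It is not scikit-learn's code: balanced_class_weight is a hypothetical stand-in that copies the keyword-only signature and the documented 'balanced' formula, n_samples / (n_classes * count(c)), and the labels are made-up toy data standing in for train_gen.classes.

```python
from collections import Counter


def balanced_class_weight(class_weight, *, classes, y):
    """Hypothetical stand-in mirroring compute_class_weight(class_weight, *, classes, y).

    The bare `*` makes `classes` and `y` keyword-only, so a call with
    three positional arguments raises TypeError, like the one in the question.
    """
    if class_weight != "balanced":
        raise ValueError("this sketch only implements 'balanced'")
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(classes)
    # 'balanced' weight for class c: n_samples / (n_classes * count(c)),
    # so rarer classes get proportionally larger weights.
    return [n_samples / (n_classes * counts[c]) for c in classes]


# Made-up labels in place of train_gen.classes:
# 6 samples of class 0, 3 of class 1, 1 of class 2.
labels = [0] * 6 + [1] * 3 + [2]
classes = sorted(set(labels))  # plays the role of np.unique(labels)

# Positional call: fails just like the original code.
try:
    balanced_class_weight("balanced", classes, labels)
except TypeError as exc:
    print("positional call fails:", exc)

# Keyword call: works.
weights = balanced_class_weight("balanced", classes=classes, y=labels)

# Keras expects a dict mapping class index -> weight.
train_class_weights = dict(enumerate(weights))
print(train_class_weights)
```

With the real function, dict(enumerate(class_weights)) from the question's commented lines gives Keras that dict, because flow_from_directory numbers the classes 0..n-1 and np.unique returns them in sorted order.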
morteza kouhsar answered Apr 19 '26 20:04


