
Cascade Classifiers for Multiclass Problems in scikit-learn

Say I have a multiclass classification problem that is hierarchical in nature, e.g. with labels such as 'edible', 'nutritious' and '~nutritious', so that the classes can be represented as a tree:

├── edible
│   ├── nutritious
│   └── ~nutritious
└── ~edible

While one can get reasonable performance with classifiers that support multiclass classification natively, or with one-vs-one/one-vs-all schemes for those that don't, it may also be beneficial to train a separate classifier at each level and chain them, so that only the instances classified as 'edible' are passed on to be classified as nutritious or not.
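To make the idea concrete, here is a minimal sketch of the two-level scheme done by hand. The toy data and the choice of LogisticRegression are assumptions for illustration only; any scikit-learn classifier could stand in at either level.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data (made up for illustration): three well-separated blobs,
# one per leaf of the tree above.
rng = np.random.RandomState(0)
labels = ["~edible", "nutritious", "~nutritious"]
centers = [(-4.0, 0.0), (4.0, 3.0), (4.0, -3.0)]
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(30, 2)) for c in centers])
y = np.repeat(labels, 30)

# Level 1: edible vs ~edible, trained on all samples.
is_edible = y != "~edible"
level1 = LogisticRegression().fit(X, is_edible)

# Level 2: nutritious vs ~nutritious, trained only on the edible samples.
level2 = LogisticRegression().fit(X[is_edible], y[is_edible])

# Prediction: everything starts as '~edible'; only the samples that
# level 1 predicts as edible are routed to level 2.
pred_edible = level1.predict(X)
y_pred = np.full(X.shape[0], "~edible", dtype=y.dtype)
if pred_edible.any():
    y_pred[pred_edible] = level2.predict(X[pred_edible])
```

The boolean mask from the first level does the routing, which is the part I would like to wrap in a reusable estimator.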

I would like to use scikit-learn estimators as building blocks, and I am wondering whether I can make Pipeline support this or whether I would need to write my own estimator that inherits from BaseEstimator, and possibly BaseEnsemble, to do this.

This has been mentioned before by @ogrisel on the mailing list (http://sourceforge.net/mailarchive/message.php?msg_id=31417048), and I'm wondering if anyone has insights or suggestions on how to go about it.

tiao asked Jan 16 '14 00:01



1 Answer

You can write your own class as a meta-estimator that takes as constructor parameters a base_estimator and the ordered list of target classes to cascade upon. In the fit method of this meta-classifier, you slice the training data based on those classes, fit a clone of the base_estimator for each level, and store the resulting sub-classifiers as an attribute of the meta-classifier.

In the predict method, you iterate over the cascading structure again, this time calling predict on each underlying sub-classifier to slice your predictions and pass the corresponding samples on to the next level recursively. You will need a fair deal of numpy fancy indexing ;)
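A minimal sketch of such a meta-estimator, offered as an illustration and not as a reference implementation. The class name CascadeClassifier, the cascade parameter and the toy data are hypothetical and not part of scikit-learn. The sketch also assumes that each level peels off one class with a binary model, and that any label outside cascade falls through to the last entry.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.tree import DecisionTreeClassifier


class CascadeClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical cascade: one binary 'is it this class?' model per level."""

    def __init__(self, base_estimator, cascade):
        self.base_estimator = base_estimator
        self.cascade = cascade  # ordered labels; the last one is the fallback

    def fit(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        self.cascade_ = np.asarray(self.cascade)
        self.classes_ = np.unique(y)
        self.estimators_ = []
        remaining = np.ones(len(y), dtype=bool)
        for label in self.cascade_[:-1]:
            # Binary target, restricted to the samples that reach this level.
            target = y[remaining] == label
            est = clone(self.base_estimator).fit(X[remaining], target)
            self.estimators_.append(est)
            # Samples of this class stop here; the rest go down one level.
            remaining &= (y != label)
        return self

    def predict(self, X):
        X = np.asarray(X)
        # Index into cascade_ per sample; default is the fallback label.
        level = np.full(X.shape[0], len(self.cascade_) - 1)
        remaining = np.arange(X.shape[0])
        for i, est in enumerate(self.estimators_):
            if remaining.size == 0:
                break
            hit = est.predict(X[remaining]).astype(bool)
            level[remaining[hit]] = i          # fancy indexing on the subset
            remaining = remaining[~hit]        # pass the rest to the next level
        return self.cascade_[level]


# Toy data (made up): one blob per class.
rng = np.random.RandomState(0)
labels = ["~edible", "nutritious", "~nutritious"]
centers = [(-4.0, 0.0), (4.0, 3.0), (4.0, -3.0)]
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(30, 2)) for c in centers])
y = np.repeat(labels, 30)

clf = CascadeClassifier(DecisionTreeClassifier(random_state=0), cascade=labels)
y_pred = clf.fit(X, y).predict(X)
```

Keeping an index array of the samples still "in play" avoids copying labels around at each level. The sketch does not handle a level that receives no training samples, which a real implementation would need to guard against.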

You can git grep base_estimator in the source code to find existing examples of meta-estimators in the code base (like Bagging, AdaBoost, GridSearchCV...).

ogrisel answered Oct 13 '22 21:10
