Information Gain calculation with Scikit-learn

I am using Scikit-learn for text classification. I want to calculate the Information Gain for each attribute with respect to a class in a (sparse) document-term matrix.

  • Information Gain is defined as H(Class) - H(Class | Attribute), where H is the entropy.
  • In Weka, this would be calculated with InfoGainAttributeEval.
  • But I haven't found this measure in scikit-learn.

(It was suggested that the formula above for Information Gain is the same measure as mutual information. This also matches the definition on Wikipedia. Is it possible to use a specific setting for mutual information in scikit-learn to accomplish this task?)
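
For illustration, here is a minimal sketch showing that the formula above and mutual information give the same value; the toy labels and the entropy_nats helper are invented for the example:

import numpy as np
from collections import Counter
from sklearn.metrics import mutual_info_score

def entropy_nats(labels):
    # Shannon entropy (in nats) of a sequence of labels
    counts = np.array(list(Counter(labels).values()), dtype=float)
    p = counts / counts.sum()
    return -np.sum(p * np.log(p))

classes   = ['spam', 'spam', 'ham', 'ham', 'ham', 'spam']
attribute = [1, 1, 0, 0, 1, 1]  # e.g. a term present/absent in a document

# H(Class | Attribute) = sum over values a of P(a) * H(Class | A = a)
h_cond = 0.0
for a in set(attribute):
    subset = [c for c, v in zip(classes, attribute) if v == a]
    h_cond += len(subset) / len(classes) * entropy_nats(subset)

info_gain = entropy_nats(classes) - h_cond
print(info_gain)                              # H(Class) - H(Class | Attribute)
print(mutual_info_score(classes, attribute))  # the same value, in nats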

asked Oct 15 '17 by Roman Purgstaller

People also ask

How is information gain used in feature selection?

Information gain calculates the reduction in entropy obtained by transforming (for example, splitting) a dataset. It can be used for feature selection by evaluating the information gain of each variable with respect to the target variable.

What is information gain in decision tree algorithm?

Information gain is a decrease in entropy: it is the difference between the entropy before a split and the weighted average entropy after splitting the dataset on a given attribute's values. The ID3 (Iterative Dichotomiser) decision tree algorithm uses information gain to choose its splits.
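
As an aside not covered above, scikit-learn's decision trees can use this criterion directly; a minimal sketch with invented toy data:

from sklearn.tree import DecisionTreeClassifier

X = [[0, 1], [1, 1], [1, 0], [0, 0]]  # two binary attributes
y = [0, 1, 1, 0]                      # class labels

# criterion='entropy' makes the tree pick splits by information gain
clf = DecisionTreeClassifier(criterion='entropy').fit(X, y)
print(clf.tree_.feature[0])  # attribute index chosen for the root split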


2 Answers

You can use scikit-learn's mutual_info_classif; here is an example:

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_selection import mutual_info_classif
from sklearn.feature_extraction.text import CountVectorizer

categories = ['talk.religion.misc',
              'comp.graphics', 'sci.space']
newsgroups_train = fetch_20newsgroups(subset='train',
                                      categories=categories)

X, Y = newsgroups_train.data, newsgroups_train.target

# Build a sparse document-term matrix of token counts
cv = CountVectorizer(max_df=0.95, min_df=2,
                     max_features=10000,
                     stop_words='english')
X_vec = cv.fit_transform(X)

# Map each vocabulary term to its mutual information (= information gain)
# with the class labels; the counts are treated as discrete features.
# Note: get_feature_names() was removed in scikit-learn 1.2;
# get_feature_names_out() is the replacement.
res = dict(zip(cv.get_feature_names_out(),
               mutual_info_classif(X_vec, Y, discrete_features=True)
               ))
print(res)

This will output a dictionary with each attribute, i.e. each term in the vocabulary, as a key and its information gain as the value.

Here is a sample of the output:

{'bible': 0.072327479595571439,
 'christ': 0.057293733680219089,
 'christian': 0.12862867565281702,
 'christians': 0.068511328611810071,
 'file': 0.048056478042481157,
 'god': 0.12252523919766867,
 'gov': 0.053547274485785577,
 'graphics': 0.13044709565039875,
 'jesus': 0.09245436105573257,
 'launch': 0.059882179387444862,
 'moon': 0.064977781072557236,
 'morality': 0.050235104394123153,
 'nasa': 0.11146392824624819,
 'orbit': 0.087254803670582998,
 'people': 0.068118370234354936,
 'prb': 0.049176995204404481,
 'religion': 0.067695617096125316,
 'shuttle': 0.053440976618359261,
 'space': 0.20115901737978983,
 'thanks': 0.060202010019767334}
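
If the goal is feature selection rather than inspecting the scores, the same scorer plugs into SelectKBest. A short sketch (an addition, reusing X_vec and Y from the code above; k=100 is an arbitrary choice):

from functools import partial
from sklearn.feature_selection import SelectKBest

# Keep the 100 terms with the highest information gain
score_func = partial(mutual_info_classif, discrete_features=True)
X_selected = SelectKBest(score_func, k=100).fit_transform(X_vec, Y)
print(X_selected.shape)  # (n_documents, 100)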
answered Oct 25 '22 by sgDysregulation

Here is my proposal for calculating information gain using pandas and SciPy:

from scipy.stats import entropy
import pandas as pd

def information_gain(members, split):
    '''
    Measures the reduction in entropy achieved by the split.
    :param members: Pandas Series of class labels
    :param split: Pandas Series of attribute values to split on
    :return: entropy before the split minus the weighted entropy after it
    '''
    entropy_before = entropy(members.value_counts(normalize=True))
    split.name = 'split'
    members.name = 'members'
    # Class distribution within each split group (one row per split value)
    grouped_distrib = members.groupby(split) \
                        .value_counts(normalize=True) \
                        .reset_index(name='count') \
                        .pivot_table(index='split', columns='members', values='count').fillna(0)
    # Entropy of each group, weighted by the group's share of the data;
    # reindex so the weights line up with the pivot table's row order
    entropy_after = entropy(grouped_distrib, axis=1)
    entropy_after *= split.value_counts(normalize=True).reindex(grouped_distrib.index)
    return entropy_before - entropy_after.sum()

members = pd.Series(['yellow','yellow','green','green','blue'])
split = pd.Series([0,0,1,1,0])
print(information_gain(members, split))
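
As a quick cross-check (an addition, not part of the original answer): the information gain of a split equals the mutual information between the two series, so scikit-learn's mutual_info_score, which also reports in nats, should agree with the function above:

from sklearn.metrics import mutual_info_score

# Should print the same value as information_gain(members, split)
print(mutual_info_score(members, split))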
answered Oct 25 '22 by Gaël Bernard