After running VarianceThreshold from scikit-learn on a set of data, a couple of features are removed. I feel I'm doing something simple yet stupid, but I'd like to retain the names of the remaining features. The following code:
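import pandas as pd
from sklearn.feature_selection import VarianceThreshold

def VarianceThreshold_selector(data):
    selector = VarianceThreshold(.5)
    selector.fit(data)
    # Wrapping the transformed array in a new DataFrame drops the column names
    selector = (pd.DataFrame(selector.transform(data)))
    return selector

x = VarianceThreshold_selector(data)
print(x)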
changes the following data (this is just a small subset of the rows):
Survived  Pclass  Sex  Age  SibSp  Parch  Nonsense
       0       3    1   22      1      0         0
       1       1    2   38      1      0         0
       1       3    2   26      0      0         0
into this (again, just a small subset of the rows):
   0     1  2  3
0  3  22.0  1  0
1  1  38.0  1  0
2  3  26.0  0  0
Using the get_support method, I know that these are Pclass, Age, SibSp, and Parch, so I'd rather this return something more like:
   Pclass   Age  SibSp  Parch
0       3  22.0      1      0
1       1  38.0      1      0
2       3  26.0      0      0
Is there an easy way to do this? I'm very new to scikit-learn, so I'm probably just doing something silly.
Features with low variance are unlikely to improve your model, so removing them is usually safe.
VarianceThreshold is a simple baseline approach to feature selection. It removes all features whose variance does not exceed a given threshold, on the assumption that features with higher variance may carry more useful information.
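To make the thresholding rule concrete, here is a minimal sketch that reproduces the selection by hand with pandas on the three sample rows from the question. It assumes VarianceThreshold compares the population variance (ddof=0) of each column against the threshold with a strict "greater than"; the variable names are illustrative only.

```python
import pandas as pd

# The three sample rows shown in the question.
df = pd.DataFrame({
    "Survived": [0, 1, 1],
    "Pclass": [3, 1, 3],
    "Sex": [1, 2, 2],
    "Age": [22, 38, 26],
    "SibSp": [1, 1, 0],
    "Parch": [0, 0, 0],
    "Nonsense": [0, 0, 0],
})

threshold = 0.5

# Population variance (ddof=0) per column. Assumption: this matches
# the variance that VarianceThreshold computes internally.
variances = df.var(ddof=0)

# Keep only the columns whose variance is strictly above the threshold.
kept = variances[variances > threshold].index.tolist()

print(variances.round(3))
print(kept)
```

On these three rows only Pclass and Age clear a threshold of 0.5. The full dataset in the question has different variances, which is why SibSp and Parch also survive there.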
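Note that for VarianceThreshold the threshold is a plain variance cutoff. The following description applies to the threshold parameter of SelectFromModel, not to VarianceThreshold: the threshold value to use for feature selection. Features whose absolute importance value is greater or equal are kept while the others are discarded. If “median” (resp. “mean”), then the threshold value is the median (resp. the mean) of the feature importances.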
Would something like this help? If you pass it a pandas DataFrame, it takes the columns and uses get_support, as you mentioned, to index into the column list and pull out only the headers of the columns that met the variance threshold.
>>> df
   Survived  Pclass  Sex  Age  SibSp  Parch  Nonsense
0         0       3    1   22      1      0         0
1         1       1    2   38      1      0         0
2         1       3    2   26      0      0         0
>>> from sklearn.feature_selection import VarianceThreshold
>>> def variance_threshold_selector(data, threshold=0.5):
...     selector = VarianceThreshold(threshold)
...     selector.fit(data)
...     return data[data.columns[selector.get_support(indices=True)]]
...
>>> variance_threshold_selector(df, 0.5)
   Pclass  Age
0       3   22
1       1   38
2       3   26
>>> variance_threshold_selector(df, 0.9)
   Age
0   22
1   38
2   26
>>> variance_threshold_selector(df, 0.1)
   Survived  Pclass  Sex  Age  SibSp
0         0       3    1   22      1
1         1       1    2   38      1
2         1       3    2   26      0
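If you need the transformed array itself, as in the question's original function, one possible variant is to wrap it back into a DataFrame and reuse the surviving column names and the original index. This is only a sketch, assuming `data` is a pandas DataFrame; the function name `variance_threshold_frame` is hypothetical.

```python
import pandas as pd
from sklearn.feature_selection import VarianceThreshold


def variance_threshold_frame(data, threshold=0.5):
    """Return the transformed data as a DataFrame with the kept column names."""
    selector = VarianceThreshold(threshold)
    transformed = selector.fit_transform(data)
    # Boolean mask with one entry per input column: True where the column was kept.
    mask = selector.get_support()
    return pd.DataFrame(transformed, columns=data.columns[mask], index=data.index)


# The three sample rows from the question.
df = pd.DataFrame({
    "Survived": [0, 1, 1],
    "Pclass": [3, 1, 3],
    "Sex": [1, 2, 2],
    "Age": [22, 38, 26],
    "SibSp": [1, 1, 0],
    "Parch": [0, 0, 0],
    "Nonsense": [0, 0, 0],
})

result = variance_threshold_frame(df, 0.5)
print(result)
```

Indexing the original DataFrame, as in the function above, keeps the original dtypes. Rebuilding from the transformed array may change them, for example when the input mixes integer and float columns.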