
Question about Permutation Importance on LSTM Keras

from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense
from keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
import eli5
from eli5.sklearn import PermutationImportance

model = Sequential()
model.add(LSTM(units=30,return_sequences= True, input_shape=(X.shape[1],421)))
model.add(Dropout(rate=0.2))
model.add(LSTM(units=30, return_sequences=True))
model.add(LSTM(units=30))
model.add(Dense(units=1, activation='relu'))

perm = PermutationImportance(model, scoring='accuracy',random_state=1).fit(X, y, epochs=500, batch_size=8)
eli5.show_weights(perm, feature_names = X.columns.tolist())

I am running an LSTM just to see the feature importance of my dataset containing 400+ features. I used the Keras scikit-learn wrapper to use eli5's PermutationImportance function. But the code is returning

ValueError: Found array with dim 3. Estimator expected <= 2.

The code runs smoothly if I use model.fit(), but I can't debug the error raised by the permutation importance call. Does anyone know what is wrong?

rorivy15 asked Mar 15 '20 05:03



1 Answer

eli5's scikit-learn implementation of permutation importance can only process 2D arrays, while Keras' LSTM layers require 3D arrays. This error is a known issue, but there appears to be no solution yet.
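Since the 2D check is the obstacle, one option is to do the shuffling by hand on the 3D array instead of going through eli5. The following is only a minimal sketch, assuming the input has shape (samples, timesteps, features) and that a higher score is better. The names permutation_importance_3d, toy_predict and neg_mse are hypothetical helpers written for this illustration; they are not part of eli5 or Keras.

```python
import numpy as np

def permutation_importance_3d(predict_fn, X, y, score_fn, n_repeats=5, seed=0):
    """Mean drop in score when one feature is shuffled across samples.

    X has shape (samples, timesteps, features); higher score_fn is better.
    """
    rng = np.random.default_rng(seed)
    baseline = score_fn(y, predict_fn(X))
    importances = np.zeros(X.shape[2])
    for j in range(X.shape[2]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            perm = rng.permutation(X.shape[0])
            # Swap whole time series of feature j between samples,
            # leaving the order of timesteps inside each series intact.
            X_perm[:, :, j] = X[:, :, j][perm]
            drops.append(baseline - score_fn(y, predict_fn(X_perm)))
        importances[j] = np.mean(drops)
    return importances

# Toy stand-in for a trained model: only feature 0 affects the output.
def toy_predict(X):
    return X[:, :, 0].mean(axis=1)

def neg_mse(y_true, y_pred):
    return -np.mean((y_true - y_pred) ** 2)

rng = np.random.default_rng(1)
X_demo = rng.normal(size=(64, 10, 3))
y_demo = toy_predict(X_demo)
importances = permutation_importance_3d(toy_predict, X_demo, y_demo, neg_mse)
```

With a trained Keras model, predict_fn could be something like lambda X: model.predict(X).ravel(), and score_fn any metric that suits the task. With 400+ features and several repeats this means many predict calls, so it may be slow.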

I understand this does not really answer your question of getting eli5 to work with LSTM (because it currently can't), but I encountered the same problem and used another library called SHAP to get the feature importance of my LSTM model. Here is some of my code to help you get started:

import shap
DE = shap.DeepExplainer(model, X_train)  # X_train is a 3D numpy.ndarray
shap_values = DE.shap_values(X_validate_np, check_additivity=False)  # X_validate_np is a 3D numpy.ndarray

shap.initjs()
shap.summary_plot(
    shap_values[0], 
    X_validate,
    feature_names=list_of_your_columns_here,
    max_display=50,
    plot_type='bar')

Here is an example of the graph which you can get:

[image: example SHAP summary bar plot]
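If the SHAP values come back as a 3D array with the same (samples, timesteps, features) layout as the input, one way to get a single ranking per feature is to average the absolute values over the sample and time axes. This is a rough sketch under that assumption; rank_features is a hypothetical helper, and the random array only stands in for the output of DE.shap_values.

```python
import numpy as np

def rank_features(shap_3d, feature_names):
    """Rank features by mean |SHAP| over samples and timesteps."""
    scores = np.abs(shap_3d).mean(axis=(0, 1))
    order = np.argsort(scores)[::-1]  # largest score first
    return [(feature_names[i], float(scores[i])) for i in order]

# Fake attributions: the per-feature scale makes f1 dominate, then f2, then f0.
rng = np.random.default_rng(0)
fake_shap = rng.normal(size=(20, 5, 3)) * np.array([0.1, 2.0, 0.5])
ranking = rank_features(fake_shap, ["f0", "f1", "f2"])
```

Averaging over timesteps throws away information about when a feature matters, so it is worth checking the per-timestep values as well if that is relevant.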

Hope this helps.

user5305519 answered Sep 22 '22 01:09