keras scikit-learn lstm eli5

Question about Permutation Importance on LSTM Keras


from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense
from keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
import eli5
from eli5.sklearn import PermutationImportance

model = Sequential()
model.add(LSTM(units=30, return_sequences=True, input_shape=(X.shape[1], 421)))
model.add(Dropout(rate=0.2))
model.add(LSTM(units=30, return_sequences=True))
model.add(LSTM(units=30))
model.add(Dense(units=1, activation='relu'))

perm = PermutationImportance(model, scoring='accuracy', random_state=1).fit(X, y, epochs=500, batch_size=8)
eli5.show_weights(perm, feature_names=X.columns.tolist())

I am training an LSTM only to inspect the feature importance of my dataset, which contains 400+ features. I used the Keras scikit-learn wrapper so that I could use eli5's PermutationImportance, but the code raises:

ValueError: Found array with dim 3. Estimator expected <= 2.

The code runs without problems if I call model.fit() directly, but I can't figure out the permutation importance error. Does anyone know what is wrong?


Solution

  • eli5's scikit-learn implementation of permutation importance can only process 2D arrays, while Keras LSTM layers require 3D arrays. This error is a known issue, and there appears to be no solution yet.
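
Because the 2D/3D mismatch comes from eli5's input validation and not from the permutation idea itself, permutation importance can also be computed by hand on a 3D array. The following is only a minimal sketch, not eli5's implementation: `permutation_importance_3d`, `toy_predict` and `neg_mse` are hypothetical names, the input is assumed to have shape (samples, timesteps, features), and the toy predictor stands in for a trained model's `predict` method.

```python
import numpy as np


def permutation_importance_3d(predict, X, y, score, n_repeats=5, seed=0):
    """Mean score drop when one feature is shuffled across samples.

    X is assumed to have shape (samples, timesteps, features).
    `predict` and `score` are placeholders for a model's predict
    function and any higher-is-better metric.
    """
    rng = np.random.default_rng(seed)
    baseline = score(y, predict(X))
    importances = np.zeros(X.shape[2])
    for j in range(X.shape[2]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Swap the whole time series of feature j between samples,
            # keeping each series' temporal order intact.
            X_perm[:, :, j] = X[rng.permutation(X.shape[0]), :, j]
            drops.append(baseline - score(y, predict(X_perm)))
        importances[j] = np.mean(drops)
    return importances


# Toy stand-in for a trained LSTM: the target depends only on feature 0.
demo_rng = np.random.default_rng(1)
X_demo = demo_rng.normal(size=(200, 10, 3))
y_demo = X_demo[:, :, 0].sum(axis=1)


def toy_predict(X):
    return X[:, :, 0].sum(axis=1)


def neg_mse(y_true, y_pred):
    # Negated so that higher is better, as the function above expects.
    return -np.mean((y_true - y_pred) ** 2)


importances = permutation_importance_3d(toy_predict, X_demo, y_demo, neg_mse)
print(importances)
```

With a real Keras model, `predict` would presumably be `model.predict` (possibly with the output flattened), and `score` a metric suited to the task. With 400+ features and several repeats this means many prediction passes, so a smaller validation subset may be advisable.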

    I realize this does not answer your question of getting eli5 to work with an LSTM (it currently can't), but I ran into the same problem and used another library, SHAP, to get the feature importance of my LSTM model. Here is some of my code to help you get started:

    import shap
    DE = shap.DeepExplainer(model, X_train)  # X_train is a 3D numpy.ndarray
    shap_values = DE.shap_values(X_validate_np, check_additivity=False)  # X_validate_np is a 3D numpy.ndarray
    
    shap.initjs()
    shap.summary_plot(
        shap_values[0], 
        X_validate,
        feature_names=list_of_your_columns_here,
        max_display=50,
        plot_type='bar')
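
shap.summary_plot is documented for a 2D (samples x features) matrix, so with 3D LSTM input it may be necessary to collapse the time axis before ranking or plotting. The following is a rough sketch of one way to do that with NumPy. It assumes that shap_values[0] has the same (samples, timesteps, features) shape as the input, which can differ between SHAP versions. The array `sv` and the names in `feature_names` are made-up stand-ins, not real output.

```python
import numpy as np

# Stand-in for shap_values[0]; assumed shape (samples, timesteps, features).
rng = np.random.default_rng(0)
sv = rng.normal(size=(50, 12, 4))
sv[:, :, 2] *= 10  # make hypothetical feature "f2" dominate

feature_names = ["f0", "f1", "f2", "f3"]  # hypothetical column names

# Mean absolute SHAP value per feature, averaged over samples and timesteps.
importance = np.abs(sv).mean(axis=(0, 1))

# Rank features from most to least important.
ranking = [feature_names[i] for i in np.argsort(importance)[::-1]]
print(ranking)
```

Averaging over the time axis discards information about which timesteps matter, so a per-timestep view (for example `np.abs(sv).mean(axis=0)`) may be worth inspecting as well.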
    
    

    This produces a bar chart that ranks the features by their mean absolute SHAP value.

    Hope this helps.