deep-learning, pytorch, neural-network, autoencoder, shap

Feature Importance of a Pytorch AutoEncoder


I need to get the importance my PyTorch autoencoder gives to each input variable. I am working with a tabular dataset, not images.

My AutoEncoder is as follows:

class AE(torch.nn.Module):
    def __init__(self, input_size, hidden_layer, latent_layer):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_layer),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer, latent_layer)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_layer, hidden_layer),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer, input_size)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

To spare unnecessary detail, I simply call the following function to get my model:

average_loss, model, train_losses, test_losses = fullAE(
    batch_size=128, input_size=genes_tensor.shape[1],
    learning_rate=0.0001, weight_decay=0, epochs=50, verbose=False,
    dataset=genes_tensor, betas_value=(0.9, 0.999),
    train_dataset=genes_tensor_train, test_dataset=genes_tensor_test)

Where "model" is a trained instance of the previous AutoEncoder:

model = AE(input_size=input_size, hidden_layer=int(input_size * 0.75), latent_layer=int(input_size * 0.5)).to(device)

Now I need to get the importance that model gives to each input variable in my original "genes_tensor" dataset, but I don't know how. After researching, I found a way to do it with the SHAP library:

import shap

e = shap.DeepExplainer(model, genes_tensor)
shap_values = e.shap_values(genes_tensor)
shap.summary_plot(shap_values, genes_tensor, feature_names=features)

The problem with this implementation is the following: 1) I don't know if what I am actually doing is correct. 2) It takes forever to finish: the dataset contains 950 samples, and even running it on a single sample takes long enough. The result using a single sample is as follows:

I have seen that there are other options for obtaining the importance of the input variables, like Captum, but Captum only lets you attribute importance to a single output neuron at a time, and in my case there are many.

The options for AEs or VAEs that I have seen on GitHub do not work for me, since they are built around specific use cases, almost always with images, for example:

https://github.com/peterparity/PDE-VAE-pytorch

https://github.com/FengNiMa/VAE-TracIn-pytorch

Is my shap implementation correct?

Edit:

I have run the shap code with only 4 samples and get the following result:

[SHAP plot produced with 4 samples]

I don't understand why it is not the typical shap summary_plot that appears everywhere.

Looking at the shap documentation, I believe it is because my model is multi-output, having more than one output neuron.
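From what I can tell, with a multi-output model, shap_values returns one attribution array per output neuron, so I assume they would need to be aggregated into a single array before plotting. Something like this (untested; newer shap versions may return a single 3-D array instead of a list):

import numpy as np

# Assumption: shap_values is a list with one (n_samples, n_features)
# array per output neuron
stacked = np.stack(shap_values, axis=0)   # (n_outputs, n_samples, n_features)
mean_abs = np.abs(stacked).mean(axis=0)   # aggregate attributions over outputs

shap.summary_plot(mean_abs, genes_tensor.cpu().numpy(), feature_names=features)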


Solution

  • Not commenting much on SHAP below, but I have some thoughts on potential alternatives. Example code at the end.

    It takes forever to finish, since the dataset contains 950 samples, I have tried to do it with only 1 sample and it takes long enough [...] Should I try other methods?

    Since SHAP is taking so long, I think it's worth considering other techniques if you think they can provide useful information which you can iterate on more quickly.

    One approach is to run permutation importance tests (example code at end). Start by training a 'good' reference model, and getting the model's reconstruction and reconstruction error using the original data. Then, for each feature_i, shuffle that feature's values across the samples (breaking its relationship with the other features), run the permuted data through the model, and record the change in the reconstruction and in the reconstruction error. Repeating the shuffle several times per feature gives you a distribution of these changes.

    This information will allow you to plot feature vs. change in recon, or feature vs. change in recon error. The first plot tells you how each feature impacts the model's output, and can be viewed as an approximation of SHAP (though I view it as a distinct and useful method in its own right). The second plot tells you how each feature impacts reconstruction accuracy. This method is relatively fast as you only need to train the model once.

    A limitation of this method is that if features are highly correlated, permutation tests can underestimate or miss a feature's importance (a situation SHAP handles better). There are ways of mitigating this, such as assessing correlations in advance and removing or grouping related features; a sketch of the grouping idea follows.
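
    As a rough sketch (assuming the X_val DataFrame, X_val_t tensor and eval_metric from the example code below), you could cluster features by correlation and permute each group as a unit:

    from scipy.cluster.hierarchy import fcluster, linkage
    from scipy.spatial.distance import squareform
    import numpy as np
    import torch
    
    # Cluster features on 1 - |correlation|, so highly correlated features group together
    corr = X_val.corr().abs()
    dist = squareform(1 - corr.values, checks=False)
    labels = fcluster(linkage(dist, method='average'), t=0.2, criterion='distance')
    
    groups = {}
    for col, lab in zip(X_val.columns, labels):
        groups.setdefault(lab, []).append(col)
    
    # Permute every column in a group with the same shuffled index: within-group
    # relationships are preserved while the group's link to the rest is broken
    rng = np.random.default_rng(0)
    for lab, cols in groups.items():
        X_perm = X_val_t.clone()
        shuffled = torch.from_numpy(rng.permutation(len(X_perm)))
        for ix in [X_val.columns.get_loc(c) for c in cols]:
            X_perm[:, ix] = X_val_t[shuffled, ix]
        # ...score X_perm with eval_metric, as in the permutation loop below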

    An alternative way of assessing feature importance for an autoencoder is to record the latent representation of each sample. You can run a mutual information analysis to see the strength of association between a feature and the latent space representation. Some features might explain more of the compressed representation than others, suggesting a relative importance.
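
    A rough sketch of that idea, assuming the trained autoencoder, the scaled arrays (X_val_a), tensors (X_val_t) and n_features from the example code below, and using scikit-learn's mutual_info_regression:

    import numpy as np
    import pandas as pd
    import torch
    from sklearn.feature_selection import mutual_info_regression
    
    # Latent representation of each validation sample
    # (autoencoder = nn.Sequential(encoder, decoder), so index 0 is the encoder)
    with torch.no_grad():
        latents = autoencoder[0](X_val_t).numpy()
    
    # MI between each input feature and each latent dimension
    mi = np.empty((n_features, latents.shape[1]))
    for j in range(latents.shape[1]):
        mi[:, j] = mutual_info_regression(X_val_a, latents[:, j], random_state=0)
    
    # Sum over latent dimensions as a crude per-feature association score
    mi_scores = pd.Series(mi.sum(axis=1), index=X_val.columns)
    print(mi_scores.sort_values(ascending=False))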

    Other techniques could look at the size of the weight learnt for each feature (perhaps in combination with a sparsity penalty), or activation sizes.
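
    For instance, a crude weight-based heuristic for the model below (whose encoder starts with a square nn.Linear layer) is to compare the norm of each input feature's outgoing weights:

    # Weight matrix of the first encoder layer has shape (out_features, in_features);
    # column i holds all the weights leaving input feature i
    first_layer = autoencoder[0][0]
    w = first_layer.weight.detach()
    feature_weight_norms = w.norm(dim=0)   # one norm per input feature
    
    for name, norm in zip(X_val.columns, feature_weight_norms.tolist()):
        print(f'{name}: {norm:.3f}')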

    For any given method, consider running it on just a portion of the dataset in order to save time, or training for only a few epochs. The results will be more approximate, but may be good enough for assessing relative feature importances.

    To minimise overfitting, you might want to run the fitting on part of the data, and then get your recons and recon errors using an unseen validation sample.


    The code below trains an autoencoder on the iris dataset (the four flower measurements plus the encoded species label) and runs a permutation test on the features. In this example some of the features were highly correlated, and since I didn't handle that, I'm not going to rely on the results below. The figures are just illustrative of what the code does.

    [Figure: permutation-test results, showing a box plot of the drop in performance per permuted feature and a bar plot of normalised feature importances]

    Imports and prepare data

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    
    #
    # Load data
    #
    from sklearn.datasets import load_iris
    from sklearn.preprocessing import QuantileTransformer
    from sklearn.model_selection import train_test_split
    
    X, y  = load_iris(return_X_y=True, as_frame=True)
    y.name = 'species'
    X = pd.concat([X, y.to_frame()], axis=1)
    n_features = X.shape[1]
    
    trn_val_ix, tst_ix = train_test_split(range(len(X)), test_size=0.1, random_state=0)
    trn_ix, val_ix = train_test_split(trn_val_ix, test_size=0.2, random_state=0)
    
    X_trn, X_val, X_tst = X.iloc[trn_ix], X.iloc[val_ix], X.iloc[tst_ix]
    
    #To numpy arrays, and scale
    scaler = QuantileTransformer(output_distribution='uniform', n_quantiles=10, random_state=0).fit(X_trn.values)
    X_trn_a, X_val_a, X_tst_a = [scaler.transform(data.values) for data in [X_trn, X_val, X_tst]]
    
    # To tensors
    import torch
    from torch import nn
    from torch import optim
    from torch.utils.data import DataLoader
    
    X_trn_t, X_val_t, X_tst_t = [torch.Tensor(data).float()
                                 for data in [X_trn_a, X_val_a, X_tst_a]]
    

    Define a simple autoencoder and a training loop:

    #
    #Define a simple autoencoder
    #
    def make_autoencoder(latent_dim_size=3, hidden_size=5):
        activation = nn.Tanh
        encoder = nn.Sequential(
            nn.Linear(n_features, n_features),
            activation(),
            nn.Linear(n_features, hidden_size),
            activation(),
            nn.Linear(hidden_size, latent_dim_size),
        )
    
        decoder = nn.Sequential(
            activation(),
            nn.Linear(latent_dim_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, n_features),
            activation(),
            nn.Linear(n_features, n_features)
        )
    
        autoencoder = nn.Sequential(encoder, decoder)
        return autoencoder
    
    print('Model size:', sum([p.numel() for p in make_autoencoder().parameters()]))
    
    @torch.no_grad()
    def eval_metric(model, loader):
        model.eval()
    
        cum_rmse_pct = 0
        for X_minibatch in loader:
            output = model(X_minibatch)
            rmse_pct = (output - X_minibatch).norm(dim=1) / X_minibatch.norm(dim=1) * 100
            cum_rmse_pct += rmse_pct.sum()
        
        return (cum_rmse_pct / loader.dataset.shape[0]).item()
    
    def train(model, loader, optimiser, n_epochs=1, loss_fn=nn.functional.mse_loss):
        metrics = {'train_loss': [], 'train_metric': [], 'val_metric': []}
        
        for epoch in range(n_epochs):
            model.train()
            cum_loss = 0
            
            for minibatch, X_minibatch in enumerate(loader):
                output = model(X_minibatch)
                
                loss = loss_fn(output, X_minibatch)
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()
                
                cum_loss += loss.item() * len(X_minibatch)
            
            #Record metrics
            metrics['train_loss'].append(cum_loss / loader.dataset.shape[0])
            metrics['train_metric'].append(eval_metric(model, loader))
            metrics['val_metric'].append(eval_metric(model, val_loader)) #val_loader is a global, defined below
            
            #Print epoch average loss
            if (epoch + 1) % 20 == 0 or (epoch == n_epochs - 1):
                print(
                    f'[epoch {epoch + 1:>3d}][minibatch {minibatch + 1:>3d}/{len(loader):>3d}]',
                    f'train loss {metrics["train_loss"][-1]:>6.3f} |',
                    f'train metric {metrics["train_metric"][-1]:>6.2f} | '
                    f'val metric {metrics["val_metric"][-1]:>6.2f}'
                )
        
        return metrics
    

    Train the model. Calling it good at ~13% reconstruction error.

    #Register optimiser and define data loaders
    batch_size = 4
    n_epochs = 200
    
    torch.manual_seed(0)
    autoencoder = make_autoencoder()
    optimiser = optim.NAdam(autoencoder.parameters())
    # optimiser = optim.SGD(autoencoder.parameters(), lr=1e-3, momentum=0.9)
    
    train_loader = DataLoader(X_trn_t, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(X_val_t, batch_size=batch_size, num_workers=2)
    
    history = train(autoencoder, train_loader, optimiser, n_epochs=n_epochs)
    
    f, ax = plt.subplots(figsize=(10, 3))
    ax.plot(history['train_loss'], linestyle='--', label='loss')
    ax2 = ax.twinx()
    ax2.plot(history['train_metric'], label='train metric')
    ax2.plot(history['val_metric'], label='val metric')
    
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax2.set_ylabel('rmse %')
    f.legend()
    

    [Figure: training loss and train/val RMSE % over epochs]

    [epoch  20][minibatch  27/ 27] train loss  0.021 | train metric  26.33 | val metric  41.81
    [epoch  40][minibatch  27/ 27] train loss  0.016 | train metric  23.17 | val metric  36.84
    [epoch  60][minibatch  27/ 27] train loss  0.006 | train metric  13.34 | val metric  17.87
    [epoch  80][minibatch  27/ 27] train loss  0.006 | train metric  12.97 | val metric  16.94
    [epoch 100][minibatch  27/ 27] train loss  0.005 | train metric  12.60 | val metric  16.59
    [epoch 120][minibatch  27/ 27] train loss  0.005 | train metric  12.32 | val metric  16.10
    [epoch 140][minibatch  27/ 27] train loss  0.005 | train metric  11.80 | val metric  15.49
    [epoch 160][minibatch  27/ 27] train loss  0.004 | train metric  10.89 | val metric  14.62
    [epoch 180][minibatch  27/ 27] train loss  0.004 | train metric  10.52 | val metric  14.28
    [epoch 200][minibatch  27/ 27] train loss  0.003 | train metric   9.63 | val metric  13.54
    

    On that trained model, run permutation tests for each feature, and plot the results. Plotted are the model's drop in performance, and a bar plot of normalised results (which can be interpreted as feature importances). These results are shown at the start of this example.

    Permutation tests:

    rng = np.random.default_rng(0)
    
    #Model's val score before permutation
    unpermuted_rmse_pct = history['val_metric'][-1]
    
    n_repeats = 50 #number of trials per feature
    permutation_metrics = np.empty([n_features, n_repeats])
    
    #Shuffle each feature in turn, and get model's score
    for col_idx, col_name in enumerate(X_val.columns):
        X_val_perm = X_val_t.clone()
        
        for repeat in range(n_repeats):
            shuffled_ixs = rng.permutation(len(X_val))
            X_val_perm[:, col_idx] = X_val_t[shuffled_ixs, col_idx]
            
            perm_loader = DataLoader(X_val_perm, batch_size=batch_size) #don't clobber the global val_loader
            permutation_metrics[col_idx, repeat] = eval_metric(autoencoder, perm_loader)
    
    #Convert to change in score compared to unpermuted data
    permutation_df = pd.DataFrame(permutation_metrics.T, columns=X_val.columns) - unpermuted_rmse_pct
    

    Plotting:

    #Box plot of change in score
    import seaborn as sns
    permutation_melt = permutation_df.melt(var_name='feature', value_name='permuted_rmse_pct')
    sns.boxplot(permutation_melt, y='feature', x='permuted_rmse_pct')
    ax = sns.stripplot(permutation_melt, y='feature', x='permuted_rmse_pct', marker='.', color='tab:red')
    
    ax.set_xlabel('drop in performance')
    ax.set_ylabel('permuted feature')
    ax.figure.set_size_inches(8, 2.5)
    plt.show()
    
    #Bar chart of feature importances
    normalised_scores = permutation_df.mean(axis=0) / permutation_df.mean(axis=0).sum() #scores 0-1
    ax = sns.barplot(normalised_scores, color='tab:purple')
    ax.tick_params(axis='x', rotation=45)
    ax.set_xlabel('feature')
    ax.set_ylabel('feature importance')
    ax.figure.set_size_inches(4, 3)