I need to get the importance my PyTorch AutoEncoder gives to each input variable. I am working with a tabular dataset, not images.
My AutoEncoder is as follows:
class AE(torch.nn.Module):
    def __init__(self, input_size, hidden_layer, latent_layer):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_layer),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer, latent_layer)
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_layer, hidden_layer),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer, input_size)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
To spare unnecessary detail, I simply call the following function to get my model:
average_loss, model, train_losses, test_losses = fullAE(
    batch_size=128, input_size=genes_tensor.shape[1],
    learning_rate=0.0001, weight_decay=0, epochs=50, verbose=False,
    dataset=genes_tensor, betas_value=(0.9, 0.999),
    train_dataset=genes_tensor_train, test_dataset=genes_tensor_test)
Where "model" is a trained instance of the previous AutoEncoder:
model = AE(input_size=input_size, hidden_layer=int(input_size * 0.75), latent_layer=int(input_size * 0.5)).to(device)
Now I need to get the importance that model gives to each input variable in my original "genes_tensor" dataset, but I don't know how. I have researched how to do it and found a way with the SHAP library:
e = shap.DeepExplainer(model, genes_tensor)
shap_values = e.shap_values(genes_tensor)
shap.summary_plot(shap_values, genes_tensor, feature_names=features)
The problems with this implementation are the following: 1) I don't know whether what I am doing is actually correct. 2) It takes forever to finish: the dataset contains 950 samples, and even a trial run with a single sample takes long enough. The result using a single sample is as follows:
I have seen that there are other options for obtaining the importance of the input variables, like Captum, but Captum only supports attributing importance to a single output neuron, and in my case there are many.
The options for AEs or VAEs that I have seen on GitHub do not work for me, since they address specific cases, almost always with images, for example:
https://github.com/peterparity/PDE-VAE-pytorch
https://github.com/FengNiMa/VAE-TracIn-pytorch
Is my shap implementation correct?
Edit:
I have run the shap code with only 4 samples and got the following result:
I don't understand why it is not the typical shap summary_plot that appears everywhere.
I have been looking at the shap documentation, and it is because my model is multi-output, having more than one neuron at the output.
Not commenting much on SHAP below, but I have some thoughts on potential alternatives. Example code at the end.
It takes forever to finish, since the dataset contains 950 samples, I have tried to do it with only 1 sample and it takes long enough [...] Should I try other methods?
Since SHAP is taking so long, I think it's worth considering other techniques if you think they can provide useful information which you can iterate on more quickly.
One approach is to run permutation importance tests (example code at end). Start by training a 'good' reference model, and getting the model's reconstruction and reconstruction error using the original data. Then, for each feature_i, shuffle that feature's values across the samples (destroying its information), run the model on the permuted data, and record the change in reconstruction and reconstruction error relative to the reference.
This information will allow you to plot feature vs. change in recon, or feature vs. change in recon error. The first plot tells you how each feature impacts the model's output, and can be viewed as an approximation of SHAP (though I view it as a distinct and useful method in its own right). The second plot tells you how each feature impacts reconstruction accuracy. This method is relatively fast as you only need to train the model once.
A limitation of this method is that if features are highly correlated, permutation tests can underestimate or miss a feature's importance (SHAP doesn't suffer from this). There are ways of mitigating this, such as assessing correlations in advance and removing or grouping related features, as sketched below.
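As a rough illustration of the grouping idea (not part of the main example, and the 0.8 threshold is arbitrary), the sketch below greedily groups features whose absolute correlation exceeds the threshold and permutes each group's columns with a single shuffled row order. It assumes the X_val, X_val_t, autoencoder, eval_metric and batch_size objects defined in the example code at the end.
import numpy as np
from torch.utils.data import DataLoader

corr = X_val.corr().abs()
groups, assigned = [], set()
for col in corr.columns:
    if col in assigned:
        continue
    #Greedily collect features correlated with this one above the threshold
    group = [c for c in corr.columns
             if corr.loc[col, c] > 0.8 and c not in assigned]
    assigned.update(group)
    groups.append(group)

rng = np.random.default_rng(0)
for group in groups:
    col_ixs = [X_val.columns.get_loc(c) for c in group]
    X_perm = X_val_t.clone()
    shuffled_ixs = rng.permutation(len(X_val))
    #Shuffle all columns of the group with the same row order
    X_perm[:, col_ixs] = X_val_t[shuffled_ixs][:, col_ixs]
    loader = DataLoader(X_perm, batch_size=batch_size)
    print(group, eval_metric(autoencoder, loader))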
An alternative way of assessing feature importance for an autoencoder is to record the latent representation of each sample. You can run a mutual information analysis to see the strength of association between a feature and the latent space representation. Some features might explain more of the compressed representation than others, suggesting a relative importance.
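For instance, a minimal sketch of that analysis (not from the main example below), assuming the trained autoencoder and the validation arrays from the example code at the end, and using scikit-learn's mutual_info_regression to score each feature against each latent dimension:
import numpy as np
import torch
from sklearn.feature_selection import mutual_info_regression

encoder = autoencoder[0]  #first half of the nn.Sequential
with torch.no_grad():
    latents = encoder(X_val_t).numpy()  #shape (n_samples, latent_dim)

#Sum the MI between each feature and every latent dimension
mi_scores = np.zeros(X_val_a.shape[1])
for dim in range(latents.shape[1]):
    mi_scores += mutual_info_regression(X_val_a, latents[:, dim], random_state=0)

for name, score in zip(X_val.columns, mi_scores):
    print(f'{name}: {score:.3f}')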
Other techniques could look at the size of the weight learnt for each feature (perhaps in combination with a sparsity penalty), or activation sizes.
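As a sketch of the weight-based idea (again an assumption-laden illustration, and only a crude proxy unless training encourages sparsity, e.g. an L1 penalty on the first layer), you could take the column-wise norm of the first Linear layer's weights, assuming the trained autoencoder from the example code at the end:
import torch
from torch import nn

#The first Linear layer consumes the input features; its weight matrix has
#shape (out_features, in_features), so a norm over dim=0 gives a per-feature size
first_linear = next(m for m in autoencoder.modules() if isinstance(m, nn.Linear))
with torch.no_grad():
    per_feature_norm = first_linear.weight.norm(dim=0)

for name, w in zip(X_val.columns, per_feature_norm.tolist()):
    print(f'{name}: {w:.3f}')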
For any given method, consider running it on just a portion of the dataset in order to save time, or training for only a few epochs. The results will be more approximate, but may be good enough for assessing relative feature importances.
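For example, a quick random subsample (a minimal sketch, assuming the genes_tensor from the question; the subsample size is arbitrary):
import torch

n_subsample = 200  #arbitrary; trades accuracy for speed
subset_ixs = torch.randperm(len(genes_tensor))[:n_subsample]
genes_subset = genes_tensor[subset_ixs]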
To minimise overfitting, you might want to run the fitting on part of the data, and then get your recons and recon errors using an unseen validation sample.
The code below trains an autoencoder on the iris features and runs a permutation test on them. In this example some of the features were highly correlated, and since I didn't handle that, I'm not going to rely on the results below. The figures are just illustrative of what the code does.
Imports and data preparation:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
#
# Load data
#
from sklearn.datasets import load_iris
from sklearn.preprocessing import QuantileTransformer
from sklearn.model_selection import train_test_split
X, y = load_iris(return_X_y=True, as_frame=True)
y.name = 'species'
X = pd.concat([X, y.to_frame()], axis=1)
n_features = X.shape[1]
trn_val_ix, tst_ix = train_test_split(range(len(X)), test_size=0.1, random_state=0)
trn_ix, val_ix = train_test_split(trn_val_ix, test_size=0.2, random_state=0)
X_trn, X_val, X_tst = X.iloc[trn_ix], X.iloc[val_ix], X.iloc[tst_ix]
#To numpy arrays, and scale
scaler = QuantileTransformer(output_distribution='uniform', n_quantiles=10, random_state=0).fit(X_trn.values)
X_trn_a, X_val_a, X_tst_a = [scaler.transform(data.values) for data in [X_trn, X_val, X_tst]]
# To tensors
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
X_trn_t, X_val_t, X_tst_t = [torch.Tensor(data).float()
                             for data in [X_trn_a, X_val_a, X_tst_a]]
Define a simple autoencoder and a training loop:
#
#Define a simple autoencoder
#
def make_autoencoder(latent_dim_size=3, hidden_size=5):
    activation = nn.Tanh
    encoder = nn.Sequential(
        nn.Linear(n_features, n_features),
        activation(),
        nn.Linear(n_features, hidden_size),
        activation(),
        nn.Linear(hidden_size, latent_dim_size),
    )
    decoder = nn.Sequential(
        activation(),
        nn.Linear(latent_dim_size, hidden_size),
        activation(),
        nn.Linear(hidden_size, n_features),
        activation(),
        nn.Linear(n_features, n_features)
    )
    autoencoder = nn.Sequential(encoder, decoder)
    return autoencoder
print('Model size:', sum([p.numel() for p in make_autoencoder().parameters()]))
@torch.no_grad()
def eval_metric(model, loader):
    model.eval()
    cum_rmse_pct = 0
    for X_minibatch in loader:
        output = model(X_minibatch)
        rmse_pct = (output - X_minibatch).norm(dim=1) / X_minibatch.norm(dim=1) * 100
        cum_rmse_pct += rmse_pct.sum()
    return (cum_rmse_pct / loader.dataset.shape[0]).item()
def train(model, loader, optimiser, n_epochs=1, loss_fn=nn.functional.mse_loss):
    metrics = {'train_loss': [], 'train_metric': [], 'val_metric': []}
    for epoch in range(n_epochs):
        model.train()
        cum_loss = 0
        for minibatch, X_minibatch in enumerate(loader):
            output = model(X_minibatch)
            loss = loss_fn(output, X_minibatch)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            cum_loss += loss.item() * len(X_minibatch)

        #Record metrics (val_loader is the global defined below)
        metrics['train_loss'].append(cum_loss / loader.dataset.shape[0])
        metrics['train_metric'].append(eval_metric(model, loader))
        metrics['val_metric'].append(eval_metric(model, val_loader))

        #Print epoch average loss
        if (epoch + 1) % 20 == 0 or (epoch == n_epochs - 1):
            print(
                f'[epoch {epoch + 1:>3d}][minibatch {minibatch + 1:>3d}/{len(loader):>3d}]',
                f'train loss {metrics["train_loss"][-1]:>6.3f} |',
                f'train metric {metrics["train_metric"][-1]:>6.2f} | '
                f'val metric {metrics["val_metric"][-1]:>6.2f}'
            )
    return metrics
Train the model. Calling it good at ~13% reconstruction error.
#Register optimiser and define data loaders
batch_size = 4
n_epochs = 200
torch.manual_seed(0)
autoencoder = make_autoencoder()
optimiser = optim.NAdam(autoencoder.parameters())
# optimiser = optim.SGD(autoencoder.parameters(), lr=1e-3, momentum=0.9)
train_loader = DataLoader(X_trn_t, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(X_val_t, batch_size=batch_size, num_workers=2)
history = train(autoencoder, train_loader, optimiser, n_epochs=n_epochs)
f, ax = plt.subplots(figsize=(10, 3))
ax.plot(history['train_loss'], linestyle='--', label='loss')
ax2 = ax.twinx()
ax2.plot(history['train_metric'], label='train metric')
ax2.plot(history['val_metric'], label='val metric')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax2.set_ylabel('rmse %')
f.legend()
[epoch 20][minibatch 27/ 27] train loss 0.021 | train metric 26.33 | val metric 41.81
[epoch 40][minibatch 27/ 27] train loss 0.016 | train metric 23.17 | val metric 36.84
[epoch 60][minibatch 27/ 27] train loss 0.006 | train metric 13.34 | val metric 17.87
[epoch 80][minibatch 27/ 27] train loss 0.006 | train metric 12.97 | val metric 16.94
[epoch 100][minibatch 27/ 27] train loss 0.005 | train metric 12.60 | val metric 16.59
[epoch 120][minibatch 27/ 27] train loss 0.005 | train metric 12.32 | val metric 16.10
[epoch 140][minibatch 27/ 27] train loss 0.005 | train metric 11.80 | val metric 15.49
[epoch 160][minibatch 27/ 27] train loss 0.004 | train metric 10.89 | val metric 14.62
[epoch 180][minibatch 27/ 27] train loss 0.004 | train metric 10.52 | val metric 14.28
[epoch 200][minibatch 27/ 27] train loss 0.003 | train metric 9.63 | val metric 13.54
On that trained model, run permutation tests for each feature, and plot the results: a box plot of the model's drop in performance per permuted feature, and a bar plot of the normalised results (which can be interpreted as feature importances). These results are shown at the start of this example.
Permutation tests:
rng = np.random.default_rng(0)

#Model's val score before permutation
unpermuted_rmse_pct = history['val_metric'][-1]

n_repeats = 50  #number of trials per feature
permutation_metrics = np.empty([n_features, n_repeats])

#Shuffle each feature in turn, and get model's score
for col_idx, col_name in enumerate(X_val.columns):
    X_val_perm = X_val_t.clone()
    for repeat in range(n_repeats):
        shuffled_ixs = rng.permutation(len(X_val))
        X_val_perm[:, col_idx] = X_val_t[shuffled_ixs, col_idx]
        val_loader = DataLoader(X_val_perm, batch_size=batch_size, shuffle=True)
        permutation_metrics[col_idx, repeat] = eval_metric(autoencoder, val_loader)

#Convert to change in score compared to unpermuted data
permutation_df = pd.DataFrame(permutation_metrics.T, columns=X_val.columns) - unpermuted_rmse_pct
Plotting:
#Box plot of change in score
import seaborn as sns
permutation_melt = permutation_df.melt(var_name='feature', value_name='permuted_rmse_pct')
sns.boxplot(permutation_melt, y='feature', x='permuted_rmse_pct')
ax = sns.stripplot(permutation_melt, y='feature', x='permuted_rmse_pct', marker='.', color='tab:red')
ax.set_xlabel('drop in performance')
ax.set_ylabel('permuted feature')
ax.figure.set_size_inches(8, 2.5)
plt.show()
#Bar chart of feature importances
normalised_scores = permutation_df.mean(axis=0) / permutation_df.mean(axis=0).sum() #scores 0-1
ax = sns.barplot(normalised_scores, color='tab:purple')
ax.tick_params(axis='x', rotation=45)
ax.set_xlabel('feature')
ax.set_ylabel('feature importance')
ax.figure.set_size_inches(4, 3)