I am working on neuron importance for ANNs in a classification setting. A simple baseline is the partial derivative of the model output for the correct class with respect to a given neuron's pre-activation. My goal is to find the important neurons for a cluster of related inputs, so I currently average the per-sample importances within each cluster to get per-cluster importances. The problem is that the averaged partial derivatives are very small for every neuron, and the resulting importance scores are poor: removing the most important neurons from the model does not reduce performance any more than removing neurons at random. Other methods, such as using the pre-activations themselves as importance scores, do better in that regard. However, experiments from other authors suggest that the partial-derivative-based approach should work reasonably well in many settings, so I suspect there is a bug in my code.
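For concreteness, the per-cluster score described above can be written as follows. The notation is introduced here only for illustration and is an assumption, not something taken from a specific paper: f_c is the model output for class c, z_j is the pre-activation of neuron j, y(x) is the correct class of input x, and C is a cluster of inputs.

```latex
I_j(C) = \frac{1}{|C|} \sum_{x \in C} \frac{\partial f_{y(x)}(x)}{\partial z_j(x)}
```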
My current setup is as follows. I know the code is not optimal from a performance point of view; it is written with clarity in mind. My question is whether there is a mistake in the code that calculates the partial derivatives with respect to the neurons' pre-activations. Assume that the layer passed to the attribute function is a PyTorch Linear layer, so the last argument of the hook (grad_output) should hold the partial derivative with respect to the pre-activation.
from abc import ABC, abstractmethod
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader


class AttributionMethod(ABC):
    def __init__(self, model: nn.Module, data_loader: Dict[int, DataLoader]):
        self.model = model
        self.data_loader = data_loader
        if self.model is not None:
            self.model.eval()

    @abstractmethod
    def attribute(self, layer: Module, latent_class: int) -> np.ndarray:
        pass


class GradientAttribution(AttributionMethod):
    def attribute(self, layer: Module, latent_class: int) -> np.ndarray:
        # Calculates the importance scores of the neurons in the provided layer for the provided cluster
        device = next(self.model.parameters()).device
        inputs, targets = next(iter(self.data_loader[latent_class]))
        inputs = inputs.to(device)
        targets = targets.to(device)
        inputs.requires_grad = True
        attributions = []

        def hook_fn(module, grad_input, grad_output):
            # grad_output[0] is the gradient with respect to the layer output (the pre-activation)
            attributions.append(grad_output[0].detach().cpu())

        handle = layer.register_full_backward_hook(hook_fn)
        # Forward and backward pass, one sample at a time
        for i in range(len(inputs)):
            input_i = inputs[i].unsqueeze(0)  # Add batch dimension
            output = self.model(input_i).squeeze()
            target_i = targets[i].item()
            output[target_i].backward()
            # We do not need to zero the grads since only the model param grads are accumulated
        handle.remove()
        # Concatenate over batch and average per neuron
        grads_tensor = torch.cat(attributions, dim=0)
        grads_per_neuron = grads_tensor.mean(dim=0).numpy()
        return grads_per_neuron
My question is whether there is a mistake in the code that calculates the partial derivatives with respect to the neurons' pre-activations.
Since the question is whether there is a mistake in the code, I think the answer is no. Your problem more likely lies in the complexity of your data and task.
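One way to check the hook independently is to compare what it records against torch.autograd.grad taken directly with respect to the layer output. The following is only a minimal sketch under assumptions: the small nn.Sequential model, the random input, and the target index are hypothetical stand-ins, not your actual network or data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model; any network containing an nn.Linear would do.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2)).eval()
layer = model[0]

x = torch.randn(1, 3, requires_grad=True)
target = 1  # assumed index of the "correct" class

# 1) Gradient as seen by the full backward hook (the approach from the question).
hook_grads = []
handle = layer.register_full_backward_hook(
    lambda module, grad_input, grad_output: hook_grads.append(grad_output[0].detach())
)
model(x)[0, target].backward()
handle.remove()

# 2) Reference: capture the pre-activation tensor in a forward hook and
#    differentiate the target logit with respect to it directly.
captured = []
handle = layer.register_forward_hook(
    lambda module, inputs, output: captured.append(output)
)
logit = model(x)[0, target]
handle.remove()
(reference_grad,) = torch.autograd.grad(logit, captured[0])

max_abs_diff = (hook_grads[0] - reference_grad).abs().max().item()
hooks_agree = torch.allclose(hook_grads[0], reference_grad)
print("hook and autograd.grad agree:", hooks_agree)
```

If the two gradients match on your real model and layer as well, the hook itself is unlikely to be the source of the poor importance scores.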
You can verify this with a toy example. Consider a model like the one below, which reads 3 features and feeds each feature into a separate tower. Each tower consists of a single Linear layer (hidden_dim neurons) followed by a ReLU.
At the end, the outputs of the towers are concatenated and combined into logits for 2 classes (this can be extended to the multiclass case if you want).
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(nn.Module):
    def __init__(self, hidden_dim=4):
        super().__init__()
        # One tower per input feature
        self.tower0 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.tower1 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.tower2 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.combiner = nn.Linear(3 * hidden_dim, 2)

    def forward(self, x):
        # Slice with a list index to keep the feature dimension: shape (batch, 1)
        x0 = x[:, [0]]
        x1 = x[:, [1]]
        x2 = x[:, [2]]
        h0 = self.tower0(x0)
        h1 = self.tower1(x1)
        h2 = self.tower2(x2)
        h = torch.cat([h0, h1, h2], dim=1)
        return self.combiner(h)
Generate artificial training data such that the label depends only on the middle (2nd) feature, so the neurons of its associated tower should have high importance.
N = 200
X = np.random.randn(N, 3).astype(np.float32)
y = (X[:, 1] > 0).astype(np.int64)
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
Train it:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyModel().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(10):
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        pred = model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        opt.step()
    # Report loss and accuracy of the last batch every 5 epochs
    if (epoch+1) % 5 == 0:
        acc = (pred.argmax(1) == yb).float().mean().item()
        print(f"Epoch {epoch+1}, Loss={loss.item():.4f}, Acc={acc:.3f}")
Then use your exact code to compute the average importance of the neurons in each of the 3 towers, measured by the average partial derivative.
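def gradient_attribution(model, layer, inputs, targets):
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Everything else is the same as in your code
    inputs.requires_grad = True
    attributions = []
    def hook_fn(module, grad_input, grad_output):
        attributions.append(grad_output[0].detach().cpu())
    handle = layer.register_full_backward_hook(hook_fn)
    for i in range(len(inputs)):
        output = model(inputs[i].unsqueeze(0)).squeeze()
        output[targets[i].item()].backward()
    handle.remove()
    return torch.cat(attributions, dim=0).mean(dim=0).numpy()

xb, yb = next(iter(loader))
# using just the linear layers (pre-activation) like you did
attr0 = gradient_attribution(model, model.tower0[0], xb, yb)
attr1 = gradient_attribution(model, model.tower1[0], xb, yb)
attr2 = gradient_attribution(model, model.tower2[0], xb, yb)
print("Tower0 importance:", abs(attr0).mean())
print("Tower1 importance:", abs(attr1).mean())
print("Tower2 importance:", abs(attr2).mean())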
I see output like this
Tower0 importance: 0.083575
Tower1 importance: 0.3273803 <- gets the highest importance as expected
Tower2 importance: 0.05412347
However, in real-world problems (unlike in this toy example), the patterns might be more subtle. A good starting point is to tweak your data and model to create strong patterns like I did, and then see whether more meaningful importance scores emerge.
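To tie the toy example back to the ablation test described in the question, the sketch below zeroes out one tower at a time and compares the resulting accuracy with the unablated baseline. It is a self-contained illustration under assumptions: the fixed seeds, the full-batch training loop, and the helper names accuracy and accuracy_without_tower are hypothetical choices made for this sketch, and the ablation is done per tower rather than per neuron.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
np.random.seed(0)


class ToyModel(nn.Module):
    def __init__(self, hidden_dim=4):
        super().__init__()
        self.tower0 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.tower1 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.tower2 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
        self.combiner = nn.Linear(3 * hidden_dim, 2)

    def forward(self, x):
        h = [self.tower0(x[:, [0]]), self.tower1(x[:, [1]]), self.tower2(x[:, [2]])]
        return self.combiner(torch.cat(h, dim=1))


# Same data recipe as above: the label depends only on the 2nd feature.
X = torch.from_numpy(np.random.randn(400, 3).astype(np.float32))
y = (X[:, 1] > 0).long()

# Full-batch training (an assumption made to keep the sketch short).
model = ToyModel()
opt = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()
for _ in range(200):
    opt.zero_grad()
    loss_fn(model(X), y).backward()
    opt.step()
model.eval()


def accuracy(model, X, y):
    """Fraction of samples classified correctly."""
    with torch.no_grad():
        return (model(X).argmax(dim=1) == y).float().mean().item()


def accuracy_without_tower(model, tower, X, y):
    """Accuracy when the given tower's output is replaced by zeros."""
    handle = tower.register_forward_hook(
        lambda module, inputs, output: torch.zeros_like(output)
    )
    try:
        return accuracy(model, X, y)
    finally:
        handle.remove()


baseline_acc = accuracy(model, X, y)
ablated_acc = [
    accuracy_without_tower(model, tower, X, y)
    for tower in (model.tower0, model.tower1, model.tower2)
]
print("Baseline accuracy:", baseline_acc)
for k, acc in enumerate(ablated_acc):
    print(f"Accuracy without tower{k}:", acc)
```

Since the label is independent of the other two features, removing tower1 should push accuracy toward chance, while removing tower0 or tower2 should matter much less. The same comparison on your real model, using the neurons ranked by gradient versus randomly chosen neurons, is the test you described.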