
Calculating the partial derivative of PyTorch model output with respect to neurons' pre-activations


I am working on neuron importance for ANNs (in a classification setting). A simple baseline is the partial derivative of the model output for the correct class with respect to the given neuron's pre-activation. My goal is to find the important neurons for a cluster of related inputs, so I currently just average the importances over the samples in each cluster to get per-cluster importances. The problem is that the partial derivatives are very small on average, and the resulting importance scores are not good: removing the most important neurons from the model does not decrease performance more than removing the same number of neurons at random. Other methods, such as using the pre-activations themselves as importance scores, work better in that regard. However, experiments by other authors suggest that the partial-derivative approach should work reasonably well in many settings, so I suspect there is a bug in my code.
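The removal test described above can be sketched as follows. This is a minimal sketch under assumptions not stated in the question: "removing" a neuron is modelled as zeroing the output of the hooked Linear layer (for a Linear followed by ReLU this also zeroes the activation, since ReLU(0) = 0), neurons are ranked by the absolute value of their score, and accuracy_with_ablation and ablation_comparison are hypothetical helper names.

```python
import torch
import torch.nn as nn


def accuracy_with_ablation(model, layer, neuron_idx, inputs, targets):
    """Hypothetical helper: accuracy of `model` with the given output units of `layer` zeroed.

    Assumption: removing a neuron = setting its pre-activation (the Linear output) to 0.
    """
    def zero_neurons(module, args, output):
        idx = torch.as_tensor(neuron_idx, dtype=torch.long, device=output.device)
        output = output.clone()
        output[:, idx] = 0.0  # knock out the selected neurons
        return output  # returning a tensor replaces the layer's output

    handle = layer.register_forward_hook(zero_neurons)
    try:
        model.eval()
        with torch.no_grad():
            preds = model(inputs).argmax(dim=1)
    finally:
        handle.remove()  # always restore the unmodified model
    return (preds == targets).float().mean().item()


def ablation_comparison(model, layer, scores, inputs, targets, k, seed=0):
    """Hypothetical helper: accuracy after removing the top-k neurons by |score| vs. k random ones."""
    top_k = torch.as_tensor(scores).abs().argsort(descending=True)[:k]
    gen = torch.Generator().manual_seed(seed)
    random_k = torch.randperm(len(scores), generator=gen)[:k]
    return (
        accuracy_with_ablation(model, layer, top_k, inputs, targets),
        accuracy_with_ablation(model, layer, random_k, inputs, targets),
    )
```

If the attribution scores are meaningful, the first returned accuracy should drop noticeably more than the second; averaging over several random seeds makes the random baseline less noisy.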

My current setup is as follows. I know the code is not optimal from a performance point of view; it is written with clarity in mind. My question is whether there is a mistake in the code that calculates the partial derivatives with respect to the neurons' pre-activations. Assume that the layer passed to the attribute function is a PyTorch Linear layer, so the last argument of the hook (grad_output) should be the partial derivative with respect to the pre-activation.

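from abc import ABC, abstractmethod
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader


class AttributionMethod(ABC):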
    def __init__(self, model: nn.Module, data_loader: Dict[int, DataLoader]):
        self.model = model
        self.data_loader = data_loader
        if self.model is not None:
            self.model.eval()

    @abstractmethod
    def attribute(self, layer: Module, latent_class: int) -> np.ndarray:
        pass


class GradientAttribution(AttributionMethod):
    def attribute(self, layer: Module, latent_class: int) -> np.ndarray:
        # Calculates the importance scores of the neurons in the provided layer for the provided cluster
        device = next(self.model.parameters()).device

        inputs, targets = next(iter(self.data_loader[latent_class]))
        inputs = inputs.to(device)
        targets = targets.to(device)
        inputs.requires_grad = True

        attributions = []

        def hook_fn(module, grad_input, grad_output):
            attributions.append(grad_output[0].detach().cpu())

        handle = layer.register_full_backward_hook(hook_fn)

        # Forward and backward pass
        for i in range(len(inputs)):
            input_i = inputs[i].unsqueeze(0)  # Add batch dimension
            output = self.model(input_i).squeeze()
            target_i = targets[i].item()

            output[target_i].backward()
            # We do not need to zero the grads since only the model param grads are accumulated

        handle.remove()

        # Concatenate over batch and average per neuron
        grads_tensor = torch.cat(attributions, dim=0)
        grads_per_neuron = grads_tensor.mean(dim=0).numpy()
        return grads_per_neuron
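One way to check the assumption that grad_output[0] of a full backward hook on a Linear layer is the derivative of the target output with respect to the pre-activation is to compare it with torch.autograd.grad on a small model. This is a minimal sketch; the stand-in architecture and sizes are arbitrary assumptions, not the asker's model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model (an assumption, not the asker's architecture):
# model[0] is the hooked Linear layer, so its output is the hidden pre-activation.
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 3)).eval()
layer = model[0]
x = torch.randn(1, 5, requires_grad=True)
target = 1

# 1) Gradient captured by a full backward hook, as in the question's code.
captured = []


def hook_fn(module, grad_input, grad_output):
    captured.append(grad_output[0].detach())


handle = layer.register_full_backward_hook(hook_fn)
model(x).squeeze()[target].backward()
handle.remove()
hook_grad = captured[0]

# 2) Reference: rerun the forward pass step by step and ask autograd
#    directly for d output[target] / d pre-activation.
pre_act = layer(x)
out = model[2](model[1](pre_act)).squeeze()
(ref_grad,) = torch.autograd.grad(out[target], pre_act)

print(torch.allclose(hook_grad, ref_grad))  # expected: True
```

If this prints True for your own model and layer as well, the hook is returning the quantity the question intends to measure.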

Solution

  • My question is whether there is a mistake in the code that calculates the partial derivatives with respect to the neurons' pre-activations.

    To answer that question directly: I don't think there is a mistake in the code. Your problem more likely lies in the complexity of your data and task.

    You can verify this with a toy example. Consider a model like this:

    [Figure: a 3-tower model]

    The model reads 3 features and feeds each feature into a separate tower. Each tower is a single Linear + ReLU layer with hidden_dim neurons (4 by default).

    Finally, the tower outputs are concatenated and a linear combiner predicts 2 classes (this can be extended to the multiclass case if you want).

    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self, hidden_dim=4):
            super().__init__()
            self.tower0 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
            self.tower1 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
            self.tower2 = nn.Sequential(nn.Linear(1, hidden_dim), nn.ReLU())
            self.combiner = nn.Linear(3 * hidden_dim, 2)
    
        def forward(self, x):
            x0 = x[:, [0]]
            x1 = x[:, [1]]
            x2 = x[:, [2]]
            h0 = self.tower0(x0)
            h1 = self.tower1(x1)
            h2 = self.tower2(x2)
            h = torch.cat([h0, h1, h2], dim=1)
            return self.combiner(h)
    

    Generate artificial training data such that the label depends only on the middle (2nd) feature, so the neurons of its tower should get high importance.

    import numpy as np
    from torch.utils.data import DataLoader, TensorDataset

    N = 200
    X = np.random.randn(N, 3).astype(np.float32)
    y = (X[:, 1] > 0).astype(np.int64)
    
    dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    

    Train it:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ToyModel().to(device)
    
    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch in range(10):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            pred = model(xb)
            loss = loss_fn(pred, yb)
            loss.backward()
            opt.step()
        if (epoch+1) % 5 == 0:
            acc = (pred.argmax(1) == yb).float().mean().item()  # accuracy on the last mini-batch
            print(f"Epoch {epoch+1}, Loss={loss.item():.4f}, Acc={acc:.3f}")
    

    Then use your exact code, wrapped in a standalone function, to compute each tower's average neuron importance, measured by the average partial derivative.

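    def gradient_attribution(model, layer, inputs, targets):
        # Same logic as GradientAttribution.attribute in the question
        device = next(model.parameters()).device
        inputs = inputs.to(device)
        targets = targets.to(device)
        inputs.requires_grad = True
        attributions = []

        def hook_fn(module, grad_input, grad_output):
            # grad_output[0]: gradient w.r.t. the Linear layer's output (pre-activation)
            attributions.append(grad_output[0].detach().cpu())

        handle = layer.register_full_backward_hook(hook_fn)
        for i in range(len(inputs)):
            output = model(inputs[i].unsqueeze(0)).squeeze()
            output[targets[i].item()].backward()
        handle.remove()
        # Concatenate over samples and average per neuron
        return torch.cat(attributions, dim=0).mean(dim=0).numpy()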
    
    xb, yb = next(iter(loader))
    # Hook only the Linear layers (pre-activation), as in your code
    attr0 = gradient_attribution(model, model.tower0[0], xb, yb)
    attr1 = gradient_attribution(model, model.tower1[0], xb, yb)
    attr2 = gradient_attribution(model, model.tower2[0], xb, yb)
    
    print("Tower0 importance:", abs(attr0).mean())
    print("Tower1 importance:", abs(attr1).mean())
    print("Tower2 importance:", abs(attr2).mean())
    

    I see output like this:

    Tower0 importance: 0.083575
    Tower1 importance: 0.3273803 <- gets the highest importance as expected
    Tower2 importance: 0.05412347
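For this toy model the hooked quantity has a simple closed form, which may help when reading the numbers above. In notation introduced here (not from the original code): d is hidden_dim, z^(k) is tower k's pre-activation (the Linear output), h^(k) = ReLU(z^(k)), y = W [h^(0); h^(1); h^(2)] + c are the combiner logits, t_n is the target of sample n, B is the batch size, and indices are 0-based. PyTorch treats the ReLU derivative at 0 as 0.

```latex
\frac{\partial y_t}{\partial z^{(k)}_j}
  = W_{t,\,kd+j}\,\mathbb{1}\!\left[z^{(k)}_j > 0\right],
\qquad
\bar g^{(k)}_j = \frac{1}{B}\sum_{n=1}^{B} W_{t_n,\,kd+j}\,\mathbb{1}\!\left[z^{(k)}_j(x_n) > 0\right],
\qquad
\text{importance}_k = \frac{1}{d}\sum_{j=0}^{d-1}\left|\bar g^{(k)}_j\right|
```

So a neuron contributes exactly zero on samples where its ReLU is inactive, and its average can be small if it is active on few samples or if the combiner weights for different target classes have opposite signs and partly cancel in the mean.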
    

    However, in real-world problems (unlike this toy example), the patterns may be much more subtle. As a starting point, you can tweak your data and model to create strong patterns, as I did here, and check whether more meaningful importances emerge.