python, machine-learning, deep-learning, pytorch, neural-network

Does using torch.where to threshold a tensor detach it from the computational graph?


I'm writing a custom loss function in PyTorch for multiclass semantic segmentation. One part of this function thresholds selected channels of the tensor, which are indicated by tracker_index.

The last part of the function that is still part of the computational graph is channel_tensor; if I comment out the line where torch.where is applied, everything runs smoothly. I've tried passing 1 and 0 as float32 tensors and ensured that they are on the same device as channel_tensor, which leads me to believe that either thresholding is not differentiable and so cannot be part of the loss function, or torch.where always detaches the tensor from the computational graph. Please advise.

channel_tensor = torch.select(
    segmentation_output,
    dim=-3,
    index=tracker_index,
)
channels[tracker_index] = torch.where(
    channel_tensor > self.threshold,
    torch.tensor(1, device=channel_tensor.device, dtype=torch.float32),
    torch.tensor(0, device=channel_tensor.device, dtype=torch.float32),
)

Solution

  • No, but ...

    torch.where(...) does not detach anything from the computational graph.

    torch.where(cond, a, b) has the same gradient as a where cond is True and the same gradient as b where cond is False.

    (So, in essence, if c = torch.where(cond, a, b), then c.grad corresponds to torch.where(cond, a.grad, b.grad).)

    In your case, though, a and b are constants, so all those gradients are 0, which effectively cuts the result off from the graph.
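
    As a quick illustration (a minimal, self-contained sketch with made-up tensors, not taken from your code), compare what happens when the branches depend on the input versus when they are constants:

    import torch

    x = torch.randn(4, requires_grad=True)

    # Branches depend on x: gradients flow through torch.where.
    y = torch.where(x > 0, 2.0 * x, 3.0 * x)
    y.sum().backward()
    print(x.grad)           # 2.0 where x > 0, 3.0 elsewhere

    # Branches are constants: the result no longer depends on x at all,
    # so it does not even require grad -- the graph is effectively cut here.
    z = torch.where(x > 0, torch.tensor(1.0), torch.tensor(0.0))
    print(z.requires_grad)  # False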

    You say your operation is "thresholding", but that is not what you are doing!

    Thresholding would be keeping the value unless it is above (or below) some threshold. What you are doing is setting the values below the threshold to 0 and the values above to 1, which is a Heaviside step function. It is differentiable almost everywhere, but its gradient is always 0 wherever it is defined (so it is unusable for optimization purposes).

    Fix

    You may want to replace that Heaviside function with its differentiable approximation, the sigmoid.

    The code would be something like this:

    channels[tracker_index] = torch.sigmoid(channel_tensor - self.threshold)
    
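    If you need the approximation to be sharper around the threshold, one common variant (an addition here, not part of the original answer) is to scale the argument, e.g. torch.sigmoid(k * (channel_tensor - self.threshold)) for some steepness factor k: larger k gets closer to the hard step, at the cost of vanishing gradients away from the threshold.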

    Note also that if you are looking for a loss for this data, you may want to look towards binary cross-entropy; PyTorch has a (more numerically stable) version, BCEWithLogitsLoss / binary_cross_entropy_with_logits, that takes the raw logits instead of the output of the sigmoid.
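
    For example, assuming channel_tensor holds the raw (pre-sigmoid) scores for that channel and target is a same-shaped tensor of 0./1. labels (both names are placeholders, not from your code), a sketch would be:

    import torch.nn.functional as F

    # binary_cross_entropy_with_logits applies the sigmoid internally,
    # in a numerically stable way, so pass the raw scores directly.
    loss = F.binary_cross_entropy_with_logits(channel_tensor, target)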