so I have 4 methods to calculate dice loss, and 3 of them return the same result, so I can conclude that 1 of them is calculating it wrong, but I would like to confirm it with you guys:
import torch
torch.manual_seed(0)
inputs = torch.rand((3,1,224,224))
target = torch.rand((3,1,224,224))
Method 1: flatten tensors
def method1(inputs, target):
    inputs = inputs.reshape(-1)
    target = target.reshape(-1)
    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()
    print("method1", dice)
Method 2: flatten tensors except for batch size, sum all dims
def method2(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum() / num
    print("method2", dice)
Method 3: flatten tensors except for batch size, sum dim 1
def method3(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    intersection = (inputs * target).sum(1)
    union = inputs.sum(1) + target.sum(1)
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum() / num
    print("method3", dice)
Method 4: don't flatten tensors
def method4(inputs, target):
    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    print("method4", dice)
method1(inputs, target)
method2(inputs, target)
method3(inputs, target)
method4(inputs, target)
Methods 1, 3, and 4 print 0.5006; method 2 prints 0.1669.
and it makes sense: in method 2 I flatten the inputs and targets over the three non-batch dimensions, but then I sum over both resulting dimensions instead of just dim 1, so I end up computing the single global dice score and then dividing it by the batch size.
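To confirm this: method 2's intersection and union are the same global sums as method 1's, so its result should be exactly method 1's dice divided by num. A quick check with the same random tensors as above:

```python
import torch

torch.manual_seed(0)
inputs = torch.rand((3, 1, 224, 224))
target = torch.rand((3, 1, 224, 224))
num = target.shape[0]

# Global dice over the whole batch (what method 1 computes).
global_dice = (2. * (inputs * target).sum()) / (inputs.sum() + target.sum() + 1e-8)

# Method 2: flattening to (num, -1) and then calling .sum() with no dim
# argument still reduces over *every* element, so the result is the same
# single global score -- dividing by num just scales it down.
flat_i = inputs.reshape(num, -1)
flat_t = target.reshape(num, -1)
m2 = (2. * (flat_i * flat_t).sum()) / (flat_i.sum() + flat_t.sum() + 1e-8)
m2 = m2 / num
```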
Method 4 seems to be the most efficient one, since it skips the flattening entirely.
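And methods 1 and 4 are mathematically identical: `.sum()` with no dim argument reduces over every dimension anyway, so flattening first changes nothing. A sanity check with the same random tensors:

```python
import torch

torch.manual_seed(0)
inputs = torch.rand((3, 1, 224, 224))
target = torch.rand((3, 1, 224, 224))

# .sum() with no dim argument reduces over all dimensions,
# so flattening beforehand (method 1) has no effect on the result.
d1 = (2. * (inputs.reshape(-1) * target.reshape(-1)).sum()) / (inputs.sum() + target.sum() + 1e-8)
d4 = (2. * (inputs * target).sum()) / (inputs.sum() + target.sum() + 1e-8)
```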
First, you need to decide which dice score you want to report: the dice score over all samples in the batch together (methods 1, 2 and 4) or the average of the per-sample dice scores (method 3).
If I'm not mistaken, you want method 3: you want to optimize the dice score of each sample in the batch rather than a "global" dice score. Suppose you have one "difficult" sample in an "easy" batch. The misclassified pixels of the "difficult" sample will be negligible w.r.t. all the other pixels in the batch. But if you look at the dice score of each sample separately, then the poor dice score of the "difficult" sample will not be negligible.
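A small made-up example illustrates this (toy 8x8 tensors with invented values, not your data): two well-predicted samples and one badly predicted sample still give a high global dice, while the per-sample scores expose the bad sample:

```python
import torch

# Hypothetical toy batch: two "easy" samples predicted perfectly,
# one "difficult" sample predicted badly (all values made up).
target = torch.ones((3, 1, 8, 8))
pred = target.clone()
pred[0] = 0.1  # the difficult sample: predictions far from the target

# Global dice (methods 1/4): one score over all pixels in the batch.
global_dice = (2. * (pred * target).sum()) / (pred.sum() + target.sum() + 1e-8)

# Per-sample dice (method 3): score each sample, then average.
p = pred.reshape(3, -1)
t = target.reshape(3, -1)
per_sample = (2. * (p * t).sum(1)) / (p.sum(1) + t.sum(1) + 1e-8)
mean_dice = per_sample.mean()
```

The global score stays above 0.8 because the two easy samples dominate the pixel sums, while the difficult sample's own dice is below 0.2, which drags the per-sample average down and keeps it visible to the optimizer.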