This paper proposes a hybrid CNN-Transformer model for medical image segmentation that segments organs and lesions in medical images simultaneously. The model has two output branches: one outputs the organ mask and the other outputs the lesion mask. The authors describe the testing process as follows:
In order to compare the performance of our approach with the state-of-the-art approaches, the following evaluation metrics have been used: F1-score (F1-S), Dice score (D-S), Intersection Over Union (IoU), and HD95, which are defined as follows:

$$\mathrm{F1\text{-}S} = \frac{2\,TP}{2\,TP + FP + FN}, \qquad \mathrm{IoU} = \frac{TP}{TP + FP + FN}$$
where TP is True Positives, TN is True Negatives, FP is False Positives, and FN is False Negatives, all associated with the segmentation classes of the test images. The Dice score is a macro metric, which is calculated for N testing images as follows:

$$\mathrm{D\text{-}S} = \frac{1}{N}\sum_{i=1}^{N} \frac{2\,TP_i}{2\,TP_i + FP_i + FN_i}$$

where $TP_i$, $FP_i$ and $FN_i$ are the True Positives, False Positives and False Negatives for the $i$-th image, respectively.
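To make the difference between a macro average and a pooled (micro) average concrete, here is a small plain-Python sketch. The function names and the per-image counts are hypothetical; the point is only that averaging per-image Dice scores over N images (the macro metric described above) weights every image equally, while computing Dice once from counts summed over all images lets large structures dominate.

```python
def dice_from_counts(tp: int, fp: int, fn: int, eps: float = 1e-8) -> float:
    """Dice = 2TP / (2TP + FP + FN) for a single set of counts."""
    return 2 * tp / (2 * tp + fp + fn + eps)


def macro_dice(counts):
    """Average of per-image Dice scores: (1/N) * sum_i Dice_i."""
    return sum(dice_from_counts(tp, fp, fn) for tp, fp, fn in counts) / len(counts)


def pooled_dice(counts):
    """Dice computed once from TP/FP/FN summed over all images (micro average)."""
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    fn = sum(c[2] for c in counts)
    return dice_from_counts(tp, fp, fn)


# Hypothetical per-image counts (TP_i, FP_i, FN_i): one large, well-segmented
# structure and one small, poorly segmented one.
counts = [(900, 50, 50), (2, 4, 4)]
print(round(macro_dice(counts), 4))   # 0.6404
print(round(pooled_dice(counts), 4))  # 0.9435
```

The small, badly segmented image pulls the macro score down sharply, but barely moves the pooled score.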
I am confused about how to implement these metrics (excluding HD95) as in this paper. My understanding is that, to compute the F1-score and IoU, I need to aggregate TP, FP and FN across all samples in the test set for both outputs (lesion and organ), and that the aggregation is a sum. For example, to get the overall TP, I compute TP for every output of every sample and then add all of these values together. I then do the same for FP and FN and plug the totals into the formulas.
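The aggregation described in the previous paragraph (summing TP, FP and FN over every sample and over both outputs, then applying the formulas once) can be sketched as follows. This illustrates the reading in the question, not a confirmed description of the paper's code; the data layout (`samples` as a list of dicts keyed by hypothetical output names `"organ"` and `"lesion"`) and the random masks are assumptions for illustration only.

```python
import numpy as np


def confusion_counts(pred: np.ndarray, target: np.ndarray):
    """Count TP, FP, FN for one binary mask pair (values 0/1)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    tp = int(np.sum(pred & target))
    fp = int(np.sum(pred & ~target))
    fn = int(np.sum(~pred & target))
    return tp, fp, fn


def pooled_f1_iou(samples):
    """Sum TP/FP/FN over all samples and all outputs, then apply the formulas once.

    `samples` is a hypothetical list of dicts mapping an output name to a
    (prediction, ground_truth) pair of 0/1 masks.
    """
    TP = FP = FN = 0
    for sample in samples:
        for pred, target in sample.values():
            tp, fp, fn = confusion_counts(pred, target)
            TP += tp
            FP += fp
            FN += fn
    f1 = TP / (TP + 0.5 * (FP + FN))
    iou = TP / (TP + FP + FN)
    return f1, iou


rng = np.random.default_rng(0)
samples = [
    {name: (rng.integers(0, 2, (8, 8)), rng.integers(0, 2, (8, 8)))
     for name in ("organ", "lesion")}
    for _ in range(3)
]
f1, iou = pooled_f1_iou(samples)
print(f"F1 = {f1:.3f}, IoU = {iou:.3f}")
```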
I am not sure whether my understanding is correct. For the Dice score, do I need to calculate it for every output separately and then average the results? I am not sure about that either, so I looked at the GitHub repository for this paper. The model is defined here, and the code for the testing procedure is here. The framework used is PyTorch. I don't have any knowledge of PyTorch, so I still can't understand how these metrics have been implemented, and hence I can't confirm whether my understanding is correct. Could somebody please explain the logic used to implement these metrics?
Edit 1: I went through the code for calculating TP, FP, and FN in `train_test_DTrAttUnet_BinarySegmentation.py`:

```python
TP += np.sum(((preds == 1).astype(int) +
              (yy == 1).astype(int)) == 2)
TN += np.sum(((preds == 0).astype(int) +
              (yy == 0).astype(int)) == 2)
FP += np.sum(((preds == 1).astype(int) +
              (yy == 0).astype(int)) == 2)
FN += np.sum(((preds == 0).astype(int) +
              (yy == 1).astype(int)) == 2)
```
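The `np.sum(... == 2)` pattern is an indirect way of counting voxels where two conditions hold at the same time. A small sketch with made-up masks, comparing it with an explicit `np.logical_and`:

```python
import numpy as np

# Toy binary masks (1 = foreground, 0 = background); values are made up.
preds = np.array([[1, 1, 0], [0, 1, 0]])
yy = np.array([[1, 0, 0], [1, 1, 0]])

# The repository's pattern: two 0/1 indicator arrays are added, and a voxel
# counts only where both indicators are 1, i.e. where the sum equals 2.
TP_sum_trick = np.sum(((preds == 1).astype(int) + (yy == 1).astype(int)) == 2)
FP_sum_trick = np.sum(((preds == 1).astype(int) + (yy == 0).astype(int)) == 2)

# The same counts written with an explicit logical AND.
TP_and = np.sum(np.logical_and(preds == 1, yy == 1))
FP_and = np.sum(np.logical_and(preds == 1, yy == 0))

print(TP_sum_trick, TP_and)  # 2 2
print(FP_sum_trick, FP_and)  # 1 1
```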
It seems like they do the forward pass in a for loop and accumulate these quantities, and after this loop they calculate the metrics:

```python
F1score = TP / (TP + ((1/2)*(FP+FN)) + 1e-8)
IoU = TP / (TP+FP+FN)
```
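The F1 formula in this code, TP / (TP + (FP + FN)/2), is algebraically the same as 2TP / (2TP + FP + FN), and F1 and IoU computed from the same counts are linked by F1 = 2·IoU / (1 + IoU). A quick numeric check with made-up counts (the `1e-8` guard is left out here for clarity):

```python
# Hypothetical global counts, e.g. accumulated over a whole test set.
TP, FP, FN = 120, 30, 50

f1_repo_form = TP / (TP + 0.5 * (FP + FN))   # form used in the repository code
f1_dice_form = 2 * TP / (2 * TP + FP + FN)   # usual Dice/F1 form
iou = TP / (TP + FP + FN)

print(round(f1_repo_form, 4), round(f1_dice_form, 4))   # 0.75 0.75
print(round(iou, 4), round(2 * iou / (1 + iou), 4))     # 0.6 0.75
```

So, computed on global counts, this "F1-score" is the same number as a Dice score pooled over the whole data set.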
So this means that they accumulate TP, FP and FN over all the images for both outputs and then calculate the metrics. Is that correct? The Dice score seems trickier to me; inside the loop they also calculate these quantities:

```python
for idice in range(preds.shape[0]):
    dice_scores += (2 * (preds[idice] * yy[idice]).sum()) / (
        (preds[idice] + yy[idice]).sum() + 1e-8
    )
predss = np.logical_not(preds).astype(int)
yyy = np.logical_not(yy).astype(int)
for idice in range(preds.shape[0]):
    dice_sc1 = (2 * (preds[idice] * yy[idice]).sum()) / (
        (preds[idice] + yy[idice]).sum() + 1e-8
    )
    dice_sc2 = (2 * (predss[idice] * yyy[idice]).sum()) / (
        (predss[idice] + yyy[idice]).sum() + 1e-8
    )
    dice_scores2 += (dice_sc1 + dice_sc2) / 2
```
Then, after the loop:

```python
epoch_dise = dice_scores/len(dataloader.dataset)
epoch_dise2 = dice_scores2/len(dataloader.dataset)
```
I still can't understand what is going on with the Dice score.
Let's break down their code (it may help to view the code sample side by side with the explanations below it):
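```python
import numpy as np
from tqdm import tqdm

# Condensed from the repository's test loop; `model`, `segm` (a sigmoid) and
# `dataloader` are defined elsewhere in the original script.
dice_scores, dice_scores2, TP, TN, FP, FN = 0, 0, 0, 0, 0, 0
for batch in tqdm(dataloader):
    x, y, _, _ = batch
    outputs, _ = model(x)  # output of one branch; the other is only used for the loss
    # Binarize predictions and ground truth as 0/1 integer NumPy arrays
    # (the NumPy conversion is written out here so the np.* calls below work)
    preds = (segm(outputs) > 0.5).detach().cpu().numpy().astype(int)
    yy = (y > 0.5).detach().cpu().numpy().astype(int)
    # Global voxel counts: a voxel is counted where both 0/1 indicators are 1
    TP += np.sum(((preds == 1).astype(int) + (yy == 1).astype(int)) == 2)
    TN += np.sum(((preds == 0).astype(int) + (yy == 0).astype(int)) == 2)
    FP += np.sum(((preds == 1).astype(int) + (yy == 0).astype(int)) == 2)
    FN += np.sum(((preds == 0).astype(int) + (yy == 1).astype(int)) == 2)
    # Foreground Dice, one score per index along dimension 0
    for idice in range(preds.shape[0]):
        dice_scores += ((2 * (preds[idice] * yy[idice]).sum()) /
                        ((preds[idice] + yy[idice]).sum() + 1e-8))
    # Background masks: 1 where the original mask is 0
    predss = np.logical_not(preds).astype(int)
    yyy = np.logical_not(yy).astype(int)
    # Per sample: average of foreground Dice and background Dice
    for idice in range(preds.shape[0]):
        dice_sc1 = ((2 * (preds[idice] * yy[idice]).sum()) /
                    ((preds[idice] + yy[idice]).sum() + 1e-8))
        dice_sc2 = ((2 * (predss[idice] * yyy[idice]).sum()) /
                    ((predss[idice] + yyy[idice]).sum() + 1e-8))
        dice_scores2 += (dice_sc1 + dice_sc2) / 2
epoch_dise = dice_scores/len(dataloader.dataset)
epoch_dise2 = dice_scores2/len(dataloader.dataset)
F1score = TP / (TP + ((1/2)*(FP+FN)) + 1e-8)
IoU = TP / (TP+FP+FN)
```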
With `for batch in tqdm(dataloader)`, the code iterates over all samples in the data set (or rather, over all samples accessible to the used `DataLoader`, which might be a subset or an otherwise preprocessed version of the underlying data). This implies that the accumulated values represent the "global" results, i.e. the results for the complete data set.

After applying the `model` to the sample data in the batch, `x`, via `model(x)`, the resulting predictions, `outputs`, are thresholded to a binary representation, `preds`, via `segm(outputs) > 0.5`. The `segm` function, in this case, is simply a sigmoid (see line 190 in the original code), which maps all values to the range [0, 1]. A similar step is performed for the "ground truth" (i.e. the true/known segmentation), `y`, to produce its binary representation, `yy`.
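Because the sigmoid is monotonic and equals exactly 0.5 at 0, thresholding its output at 0.5 mathematically gives the same mask as thresholding the raw outputs (logits) at 0 (up to floating-point rounding for logits extremely close to 0). A small check, with a locally defined `sigmoid` standing in for the repository's `segm` and made-up logits:

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Logistic sigmoid, mapping real-valued logits into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical raw network outputs (logits) for a few voxels
logits = np.array([-3.0, -0.2, 0.0, 0.4, 5.0])

mask_from_probs = sigmoid(logits) > 0.5
mask_from_logits = logits > 0.0

print(mask_from_probs.astype(int))                        # [0 0 0 1 1]
print(np.array_equal(mask_from_probs, mask_from_logits))  # True
```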
[Update] The `outputs` variable, in this context, holds the outputs of only one of the two branches of the model (compare the relevant model code), thus either the lesion or the organ segmentation. Also compare Figure 2 in the paper: the corresponding branches are the ones with blue and green backgrounds. While this is suppressed in my pseudocode above (`outputs, _ = model(x)`), the actual code also uses the output of the other branch, though only for loss calculation and not for the calculation of the scores/metrics. [/Update]

As a result, we have:

- `preds`, which holds the predicted segmentations for all samples in the current batch;
- `yy`, which holds the true/known segmentations for all samples in the current batch.

The `np.sum()` calculations in the following lines count (albeit in a somewhat unconventional way) the number of voxels matching between predictions and true/known segmentations, over all samples in the current batch. For example, for the false positives this means counting the voxels that are predicted foreground and true/known background; this count is then added on top of the global count, `FP`.
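One detail of this counting pattern is worth knowing: the `.astype(int)` casts in the repository code matter. With NumPy boolean arrays, `+` behaves as a logical OR rather than integer addition, so dropping the casts would silently make every count zero. A sketch with made-up 1-D masks:

```python
import numpy as np

preds = np.array([1, 1, 0, 0])
yy = np.array([1, 0, 1, 0])

# With boolean arrays, NumPy's "+" acts as a logical OR, so the sum never
# reaches 2 and the "== 2" test never fires.
bool_sum = (preds == 1) + (yy == 1)
print(bool_sum)               # [ True  True  True False]
print(np.sum(bool_sum == 2))  # 0

# Casting to int first gives 0/1/2 per voxel, so "== 2" marks true positives.
int_sum = (preds == 1).astype(int) + (yy == 1).astype(int)
print(int_sum)                # [2 1 1 0]
print(np.sum(int_sum == 2))   # 1
```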
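In the first `for idice in range(preds.shape[0])` loop, a Dice score is calculated for each index along dimension 0 of `preds` and `yy`, and the results are added to `dice_scores`. Dimension 0 usually indexes individual samples in the batch, so this would mean calculating a separate Dice score for each sample. For 0/1 integer masks, `(preds[idice] * yy[idice]).sum()` is $TP_i$ and `(preds[idice] + yy[idice]).sum()` is $2\,TP_i + FP_i + FN_i$, so each term is exactly $2\,TP_i / (2\,TP_i + FP_i + FN_i)$. The values are later normalized by dividing by the number of samples, `len(dataloader.dataset)`, to obtain `epoch_dise`. These two steps are in accordance with the equation Dice score = … shared in the question, which calculates the Dice score separately for each sample, adds all corresponding results, then divides by the number of samples (called testing images there), N.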
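The per-sample loop can also be written without an explicit Python loop by summing over every axis except the sample axis. The sketch below assumes masks shaped (batch, channel, height, width) with 0/1 integer values (an assumption; the repository's exact shapes may differ), uses a hypothetical helper `per_sample_dice`, and checks it against the loop. Dividing by the total number of samples mirrors `len(dataloader.dataset)`, so a smaller last batch is handled correctly.

```python
import numpy as np


def per_sample_dice(preds: np.ndarray, yy: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Foreground Dice for each sample of a batch of 0/1 masks.

    Sums over every axis except axis 0 (the sample axis), mirroring the
    `for idice in range(preds.shape[0])` loop.
    """
    axes = tuple(range(1, preds.ndim))
    inter = (preds * yy).sum(axis=axes)   # TP_i
    total = (preds + yy).sum(axis=axes)   # 2*TP_i + FP_i + FN_i
    return 2 * inter / (total + eps)


rng = np.random.default_rng(42)
# Hypothetical batches of shape (batch, channel, height, width); last batch smaller
batches = [(rng.integers(0, 2, (4, 1, 16, 16)), rng.integers(0, 2, (4, 1, 16, 16))),
           (rng.integers(0, 2, (2, 1, 16, 16)), rng.integers(0, 2, (2, 1, 16, 16)))]

dice_scores_loop, dice_scores_vec, n_samples = 0.0, 0.0, 0
for preds, yy in batches:
    for idice in range(preds.shape[0]):                  # loop as in the repository
        dice_scores_loop += (2 * (preds[idice] * yy[idice]).sum()) / (
            (preds[idice] + yy[idice]).sum() + 1e-8)
    dice_scores_vec += per_sample_dice(preds, yy).sum()  # vectorized equivalent
    n_samples += preds.shape[0]

epoch_dise = dice_scores_vec / n_samples  # plays the role of len(dataloader.dataset)
print(np.isclose(dice_scores_loop, dice_scores_vec), round(float(epoch_dise), 3))
```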
`predss` and `yyy` are the negations of `preds` and `yy`, respectively, i.e. 1 (`True`) for background voxels and 0 (`False`) for foreground voxels.

In the second `for idice …` loop, the Dice score for the foreground voxels is first calculated again, as `dice_sc1`; but then the Dice score for the background voxels is also calculated, as `dice_sc2`. Then their average is taken.

These averages are accumulated in `dice_scores2`, which is later normalized to `epoch_dise2`, just as `dice_scores` and `epoch_dise` above.

Finally, `F1score = …` and `IoU = …` are calculated over the global values of the true/false positives/negatives, in accordance with the corresponding equations cited in the question.

So, to summarize once more:

- The first Dice calculation (`dice_scores`, `epoch_dise`) calculates what I would call the "standard" Dice coefficient, i.e. the score for overlapping foreground voxels.
- The second Dice calculation (`dice_scores2`, `epoch_dise2`) calculates what I would call a "weighted" Dice coefficient: for each sample, it calculates both the score for overlapping foreground voxels and the score for overlapping background voxels, then averages them as the sample's score, and only then accumulates and averages again to get the global score.
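To see why the averaged ("weighted") variant usually comes out much higher than the standard foreground Dice, consider a small foreground object: the background Dice is close to 1 almost regardless of how well the object is segmented. A sketch with a hypothetical 10×10 sample and a hypothetical `dice` helper, reusing the negation step from the repository code:

```python
import numpy as np


def dice(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    """Dice coefficient of two 0/1 masks: 2|A∩B| / (|A| + |B|)."""
    return float(2 * (a * b).sum() / ((a + b).sum() + eps))


# Hypothetical 10x10 sample with a small lesion: 4 true foreground voxels,
# of which the prediction finds 2 (plus 2 false positives).
yy = np.zeros((10, 10), dtype=int)
yy[0, 0:4] = 1
preds = np.zeros((10, 10), dtype=int)
preds[0, 2:6] = 1

dice_sc1 = dice(preds, yy)                    # foreground Dice
predss = np.logical_not(preds).astype(int)
yyy = np.logical_not(yy).astype(int)
dice_sc2 = dice(predss, yyy)                  # background Dice
combined = (dice_sc1 + dice_sc2) / 2          # per-sample term added to dice_scores2

print(round(dice_sc1, 3), round(dice_sc2, 3), round(combined, 3))  # 0.5 0.979 0.74
```

Half of the lesion is missed, yet the combined score is about 0.74, because the 94 correctly identified background voxels dominate `dice_sc2`.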