Is there a way to automatically generate a confusion matrix using the same test set and its annotations?
This is the link to the RetinaNet model
I tried adapting a confusion matrix from an image classification model, which obviously didn't work. I'm currently trying to use the results from running inference.py and dataset.py, since the first produces the prediction results and the second shows a few examples of the dataset used for training.
I tried something bizarre, but it does the job (I guess). After getting two separate lists, one for the ground truth and one for the predictions, I count the occurrences of each class in both lists: the overlap between the two counts is taken as true positives, surplus predictions as false positives, and surplus ground-truth instances as false negatives.
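Just to make that counting idea concrete before the full script, here is a minimal, self-contained sketch (the class names and lists are made up purely for illustration, not taken from my data) of how the per-class counts turn into TP/FP/FN:

from collections import Counter

# Hypothetical per-image label lists, just for illustration.
ground_truth = [['wifi', 'drone'], ['bluetooth'], ['wifi']]
predictions = [['wifi'], ['bluetooth', 'wifi'], ['drone']]

# Flatten both into per-class counts.
gt_counts, pred_counts = Counter(), Counter()
for labels in ground_truth:
    gt_counts.update(labels)
for labels in predictions:
    pred_counts.update(labels)

for cls in ('bluetooth', 'wifi', 'drone'):
    TP = min(gt_counts[cls], pred_counts[cls])      # count present in both lists
    FP = max(0, pred_counts[cls] - gt_counts[cls])  # surplus predictions
    FN = max(0, gt_counts[cls] - pred_counts[cls])  # missed ground truth
    print(cls, "TP:", TP, "FP:", FP, "FN:", FN)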
This is the code. I know it's not optimised, and I reused about 60% of the original RetinaNet code to parse the data, but it works. I hope it helps some of you:
import torch
import cv2
import numpy as np
import os
import glob as glob
import argparse
import time
from collections import Counter
import matplotlib.pyplot as plt
from model import create_model
from config import (
    NUM_CLASSES, DEVICE, CLASSES, DATASET_IMAGE_WIDTH, DATASET_IMAGE_HEIGHT
)
from xml.etree import ElementTree as et
from torch.utils.data import Dataset
DIR_TEST = 'data/test'
test_images = glob.glob(f"{DIR_TEST}/*.png")
print(f"Test instances: {len(test_images)}")
# The dataset class.
class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        self.transforms = transforms
        self.dir_path = dir_path
        self.height = height
        self.width = width
        self.classes = classes
        self.image_file_types = ['*.jpg', '*.jpeg', '*.png', '*.ppm', '*.JPG']
        self.all_image_paths = []
        # Get all the image paths in sorted order.
        for file_type in self.image_file_types:
            self.all_image_paths.extend(glob.glob(os.path.join(self.dir_path, file_type)))
        self.all_images = [image_path.split(os.path.sep)[-1] for image_path in self.all_image_paths]
        self.all_images = sorted(self.all_images)

    def __getitem__(self, idx):
        # Capture the image name and the full image path.
        image_name = self.all_images[idx]
        image_path = os.path.join(self.dir_path, image_name)
        # Read and preprocess the image.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0
        # Capture the corresponding XML file for getting the annotations.
        annot_filename = os.path.splitext(image_name)[0] + '.xml'
        annot_file_path = os.path.join(self.dir_path, annot_filename)
        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()
        # Original image width and height.
        image_width = image.shape[1]
        image_height = image.shape[0]
        # Box coordinates are extracted from the XML file
        # and corrected for the resized image.
        for member in root.findall('object'):
            # Get the label and map it to `classes`.
            labels.append(self.classes.index(member.find('name').text))
            # Left corner x-coordinate.
            xmin = int(member.find('bndbox').find('xmin').text)
            # Right corner x-coordinate.
            xmax = int(member.find('bndbox').find('xmax').text)
            # Top corner y-coordinate.
            ymin = int(member.find('bndbox').find('ymin').text)
            # Bottom corner y-coordinate.
            ymax = int(member.find('bndbox').find('ymax').text)
            # Resize the bounding boxes according to the
            # resized image `width` and `height`.
            xmin_final = (xmin/image_width)*self.width
            xmax_final = (xmax/image_width)*self.width
            ymin_final = (ymin/image_height)*self.height
            ymax_final = (ymax/image_height)*self.height
            # Check that max coordinates are at least one pixel
            # larger than min coordinates.
            if xmax_final == xmin_final:
                xmax_final += 1
            if ymax_final == ymin_final:
                ymax_final += 1
            # Check that all coordinates are within the image.
            if xmax_final > self.width:
                xmax_final = self.width
            if ymax_final > self.height:
                ymax_final = self.height
            boxes.append([xmin_final, ymin_final, xmax_final, ymax_final])
        # Bounding boxes to tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Area of the bounding boxes.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) if len(boxes) > 0 \
            else torch.as_tensor(boxes, dtype=torch.float32)
        # No crowd instances.
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        # Labels to tensor.
        labels = torch.as_tensor(labels, dtype=torch.int64)
        # Prepare the final `target` dictionary.
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        image_id = torch.tensor([idx])
        target["image_id"] = image_id
        # Apply the image transforms.
        if self.transforms:
            sample = self.transforms(image=image_resized,
                                     bboxes=target['boxes'],
                                     labels=labels)
            image_resized = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])
        if np.isnan((target['boxes']).numpy()).any() or target['boxes'].shape == torch.Size([0]):
            target['boxes'] = torch.zeros((0, 4), dtype=torch.int64)
        return image_resized, target

    def __len__(self):
        return len(self.all_images)
# Ground-truth label lists (one list of class indices per test image).
GT = []
# USAGE: python datasets.py
if __name__ == '__main__':
    # Sanity check of the Dataset pipeline with sample visualization.
    dataset = CustomDataset(
        DIR_TEST, DATASET_IMAGE_WIDTH, DATASET_IMAGE_HEIGHT, CLASSES
    )

    # Function to visualize a single sample.
    def visualize_sample(image, target):
        for box_num in range(len(target['boxes'])):
            box = target['boxes'][box_num]
            label = CLASSES[target['labels'][box_num]]
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(
                image,
                (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                (0, 0, 255),
                2
            )
            cv2.putText(
                image,
                label,
                (int(box[0]), int(box[1]-5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 0, 255),
                2
            )
        cv2.imshow('Image', image)
        cv2.waitKey(0)

    # Collect the ground-truth labels of every test image.
    for i in range(len(test_images)):
        image, target = dataset[i]
        GT.append(target['labels'].tolist())
# Inference code.
np.random.seed(42)
# Construct the argument parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-i', '--input',
    help='path to input image directory',
)
parser.add_argument(
    '--imgsz',
    default=None,
    type=int,
    help='image resize shape'
)
parser.add_argument(
    '--threshold',
    default=0.25,
    type=float,
    help='detection threshold'
)
args = vars(parser.parse_args())
os.makedirs('inference_outputs/images', exist_ok=True)
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
# Load the best model and trained weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load('outputs/best_model.pth', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE).eval()
MP = []   # Per-image lists of predicted class names.
LMP = []  # All prediction lists gathered for the confusion matrix.
for i in range(len(test_images)):
    MP.append([])
    # Get the image file name for saving output later on.
    image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
    image = cv2.imread(test_images[i])
    orig_image = image.copy()
    if args['imgsz'] is not None:
        image = cv2.resize(image, (args['imgsz'], args['imgsz']))
    # BGR to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    # Make the pixel range between 0 and 1.
    image /= 255.0
    # Bring color channels to front (H, W, C) => (C, H, W).
    image_input = np.transpose(image, (2, 0, 1)).astype(np.float32)
    # Convert to tensor.
    image_input = torch.tensor(image_input, dtype=torch.float).to('cpu')
    # Add batch dimension.
    image_input = torch.unsqueeze(image_input, 0)
    start_time = time.time()
    # Predictions.
    with torch.no_grad():
        outputs = model(image_input.to(DEVICE))
    end_time = time.time()
    # Load all detections to CPU for further operations.
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    # Carry on only if there are detected boxes.
    if len(outputs[0]['boxes']) != 0:
        boxes = outputs[0]['boxes'].data.numpy()
        scores = outputs[0]['scores'].data.numpy()
        # Filter out boxes according to the detection threshold.
        boxes = boxes[scores >= args['threshold']].astype(np.int32)
        draw_boxes = boxes.copy()
        # Get all the predicted class names.
        pred_classes = [CLASSES[i] for i in outputs[0]['labels'].cpu().numpy()]
        # Draw the bounding boxes and write the class name on top of them.
        for j, box in enumerate(draw_boxes):
            class_name = pred_classes[j]
            MP[i].append(class_name)
            color = COLORS[CLASSES.index(class_name)]
            # Rescale boxes to the original image size.
            xmin = int((box[0] / image.shape[1]) * orig_image.shape[1])
            ymin = int((box[1] / image.shape[0]) * orig_image.shape[0])
            xmax = int((box[2] / image.shape[1]) * orig_image.shape[1])
            ymax = int((box[3] / image.shape[0]) * orig_image.shape[0])
            cv2.rectangle(orig_image,
                          (xmin, ymin),
                          (xmax, ymax),
                          color[::-1],
                          3)
            cv2.putText(orig_image,
                        class_name,
                        (xmin, ymin-5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.8,
                        color[::-1],
                        2,
                        lineType=cv2.LINE_AA)
    # Keep this image's predictions for the confusion matrix.
    LMP.append(MP[i])
# Map label indices to class names.
mapping = {
    1: 'bluetooth',
    2: 'wifi',
    3: 'drone'
}
# Ground truth and predictions as per-image lists of class names.
ground_truth = [[mapping[num] for num in sublist] for sublist in GT]
predictions = LMP
print("Ground truth list:", ground_truth)
print("All predicted classes lists:", predictions)
# Convert ground truth and predictions to tuples of tuples
ground_truth = tuple(tuple(sublist) for sublist in ground_truth)
predictions = tuple(tuple(sublist) for sublist in predictions)
# Initialize confusion matrices
confusion_matrices = {
    'bluetooth': np.zeros((2, 2), dtype=int),
    'wifi': np.zeros((2, 2), dtype=int),
    'drone': np.zeros((2, 2), dtype=int)
}
# Count occurrences of each class in the ground truth and the predictions.
gt_counts = Counter()
pred_counts = Counter()
for gt_list in ground_truth:
    gt_counts.update(gt_list)
for pred_list in predictions:
    pred_counts.update(pred_list)
# Populate confusion matrices
for class_n in mapping.values():
    TP = min(gt_counts[class_n], pred_counts[class_n])
    FP = max(0, pred_counts[class_n] - gt_counts[class_n])
    FN = max(0, gt_counts[class_n] - pred_counts[class_n])
    # TN is tricky to define in a multi-class detection setup; it is only a rough simplification here.
    TN = len(ground_truth) + len(predictions) - (TP + FP + FN)
    confusion_matrices[class_n][0, 0] = TP
    confusion_matrices[class_n][0, 1] = FP
    confusion_matrices[class_n][1, 0] = FN
    confusion_matrices[class_n][1, 1] = TN
# Print the confusion matrices
for class_n, matrix in confusion_matrices.items():
    print(f"Confusion Matrix for {class_n}:")
    print(matrix)
    print("TP: ", matrix[0, 0])
    print("TN: ", matrix[1, 1])
    print("FP: ", matrix[0, 1])
    print("FN: ", matrix[1, 0])
    print()
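As a side note, matplotlib.pyplot is imported above but never used. If you also want a visual version of the printed matrices, a rough sketch along these lines (not part of my original script; it only assumes the confusion_matrices dict built above) can render each per-class 2x2 matrix:

# Optional: draw each per-class 2x2 matrix (rows: predicted +/-, columns: actual +/-).
fig, axes = plt.subplots(1, len(confusion_matrices), figsize=(4 * len(confusion_matrices), 4))
for ax, (class_n, matrix) in zip(np.atleast_1d(axes), confusion_matrices.items()):
    ax.imshow(matrix, cmap='Blues')
    ax.set_title(class_n)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['actual +', 'actual -'])
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['pred +', 'pred -'])
    # Write the raw counts inside the cells.
    for r in range(2):
        for c in range(2):
            ax.text(c, r, str(matrix[r, c]), ha='center', va='center')
plt.tight_layout()
plt.show()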