python, pytorch, object-detection, confusion-matrix, retinanet

Generate a confusion matrix for RetinaNet object detection model


Is there a way to automatically generate a confusion matrix using the same test set and its annotations?

This is the link to the RetinaNet model

I tried adapting a confusion matrix from an image classification model, which obviously didn't work. I'm currently trying to use the results from running inference.py and dataset.py, since the first produces the prediction results and the second shows a few examples of the dataset used for training.


Solution

  • I tried something bizarre, but it does the job (I guess). After getting two separate lists, one for the ground truth and the second for the predictions, I count occurrences in both lists and consider each intersection as a true positive and any other class intersection as a false positive.

    This is the code. I know it's not optimised and I reused about 60% of the original RetinaNet code to parse the data, but it works. I hope it helps some of you (a small plotting sketch follows after the code):

    import torch
    import cv2
    import numpy as np
    import os
    import glob
    import argparse
    import time
    from collections import Counter
    
    import matplotlib.pyplot as plt
    
    from xml.etree import ElementTree as et
    from torch.utils.data import Dataset
    
    from model import create_model
    from config import (
        NUM_CLASSES, DEVICE, CLASSES,
        DATASET_IMAGE_WIDTH, DATASET_IMAGE_HEIGHT
    )
    
    DIR_TEST = 'data/test'
    # Sort the paths so their order matches CustomDataset, which sorts its image list.
    test_images = sorted(glob.glob(f"{DIR_TEST}/*.png"))
    print(f"Test instances: {len(test_images)}")
    
    # The dataset class.
    class CustomDataset(Dataset):
        def __init__(self, dir_path, width, height, classes, transforms=None):
            self.transforms = transforms
            self.dir_path = dir_path
            self.height = height
            self.width = width
            self.classes = classes
            self.image_file_types = ['*.jpg', '*.jpeg', '*.png', '*.ppm', '*.JPG']
            self.all_image_paths = []
            
            # Get all the image paths in sorted order.
            for file_type in self.image_file_types:
                self.all_image_paths.extend(glob.glob(os.path.join(self.dir_path, file_type)))
            self.all_images = [image_path.split(os.path.sep)[-1] for image_path in self.all_image_paths]
            self.all_images = sorted(self.all_images)
    
        def __getitem__(self, idx):
            # Capture the image name and the full image path.
            image_name = self.all_images[idx]
            image_path = os.path.join(self.dir_path, image_name)
    
            # Read and preprocess the image.
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
            image_resized = cv2.resize(image, (self.width, self.height))
            image_resized /= 255.0
            
            # Capture the corresponding XML file for getting the annotations.
            annot_filename = os.path.splitext(image_name)[0] + '.xml'
            annot_file_path = os.path.join(self.dir_path, annot_filename)
            
            boxes = []
            labels = []
            tree = et.parse(annot_file_path)
            root = tree.getroot()
            
            # Original image width and height.
            image_width = image.shape[1]
            image_height = image.shape[0]
            
            # Box coordinates for xml files are extracted 
            # and corrected for image size given.
            for member in root.findall('object'):
                # Get label and map the `classes`.
                labels.append(self.classes.index(member.find('name').text))
                
                # Left corner x-coordinates.
                xmin = int(member.find('bndbox').find('xmin').text)
                # Right corner x-coordinates.
                xmax = int(member.find('bndbox').find('xmax').text)
                # Left corner y-coordinates.
                ymin = int(member.find('bndbox').find('ymin').text)
                # Right corner y-coordinates.
                ymax = int(member.find('bndbox').find('ymax').text)
                
                # Resize the bounding boxes according 
                # to resized image `width`, `height`.
                xmin_final = (xmin/image_width)*self.width
                xmax_final = (xmax/image_width)*self.width
                ymin_final = (ymin/image_height)*self.height
                ymax_final = (ymax/image_height)*self.height
    
                # Check that max coordinates are at least one pixel
                # larger than min coordinates.
                if xmax_final == xmin_final:
                    xmax_final += 1
                if ymax_final == ymin_final:
                    ymax_final += 1
                # Check that all coordinates are within the image.
                if xmax_final > self.width:
                    xmax_final = self.width
                if ymax_final > self.height:
                    ymax_final = self.height
                
                boxes.append([xmin_final, ymin_final, xmax_final, ymax_final])
            
            # Bounding box to tensor.
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # Area of the bounding boxes.
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) if len(boxes) > 0 \
                else torch.as_tensor(boxes, dtype=torch.float32)
            # No crowd instances.
            iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
            # Labels to tensor.
            labels = torch.as_tensor(labels, dtype=torch.int64)
            # Prepare the final `target` dictionary.
            target = {}
            target["boxes"] = boxes
            target["labels"] = labels
            target["area"] = area
            target["iscrowd"] = iscrowd
            image_id = torch.tensor([idx])
            target["image_id"] = image_id
            # Apply the image transforms.
            if self.transforms:
                sample = self.transforms(image = image_resized,
                                         bboxes = target['boxes'],
                                         labels = labels)
                image_resized = sample['image']
                target['boxes'] = torch.Tensor(sample['bboxes'])
            
            if np.isnan((target['boxes']).numpy()).any() or target['boxes'].shape == torch.Size([0]):
                target['boxes'] = torch.zeros((0, 4), dtype=torch.int64)
            return image_resized, target
    
        def __len__(self):
            return len(self.all_images)
    
    
    # GT will hold one list of ground-truth label indices per test image.
    GT = []
    
    
    # Build the ground-truth label list for every test image
    # (adapted from the datasets.py sanity-check script).
    if __name__ == '__main__':
        dataset = CustomDataset(
            DIR_TEST, DATASET_IMAGE_WIDTH, DATASET_IMAGE_HEIGHT, CLASSES
        )
        
        # Function to visualize a single sample (defined but not called below).
        def visualize_sample(image, target):
            for box_num in range(len(target['boxes'])):
                box = target['boxes'][box_num]
                label = CLASSES[target['labels'][box_num]]
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                cv2.rectangle(
                    image, 
                    (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                    (0, 0, 255), 
                    2
                )
                cv2.putText(
                    image, 
                    label,
                    (int(box[0]), int(box[1]-5)), 
                    cv2.FONT_HERSHEY_SIMPLEX, 
                    0.7, 
                    (0, 0, 255), 
                    2
                )
            cv2.imshow('Image', image)
            cv2.waitKey(0)
            
        for i in range(len(dataset)):
            image, target = dataset[i]
            GT.append(target['labels'].tolist())
    
    
    
    
    
    # Inference code.
    np.random.seed(42)
    
    # Construct the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input', 
        help='path to input image directory',
    )
    parser.add_argument(
        '--imgsz', 
        default=None,
        type=int,
        help='image resize shape'
    )
    parser.add_argument(
        '--threshold',
        default=0.25,
        type=float,
        help='detection threshold'
    )
    args = vars(parser.parse_args())
    
    os.makedirs('inference_outputs/images', exist_ok=True)
    
    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
    
    # Load the best model and trained weights.
    model = create_model(num_classes=NUM_CLASSES)
    checkpoint = torch.load('outputs/best_model.pth', map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE).eval()
    
    # MP holds the predicted class names for every test image; LMP collects
    # the per-image prediction lists that are compared against GT below.
    MP = []
    LMP = []
    
    for i in range(len(test_images)):
        MP.append([])
        # Get the image file name for saving output later on.
        image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
        image = cv2.imread(test_images[i])
        orig_image = image.copy()
        if args['imgsz'] is not None:
            image = cv2.resize(image, (args['imgsz'], args['imgsz']))
        # BGR to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        # Make the pixel range between 0 and 1.
        image /= 255.0
        # Bring color channels to front (H, W, C) => (C, H, W).
        image_input = np.transpose(image, (2, 0, 1)).astype(np.float32)
        # Convert to tensor and add the batch dimension.
        image_input = torch.tensor(image_input, dtype=torch.float)
        image_input = torch.unsqueeze(image_input, 0)
        start_time = time.time()
        # Predictions
        with torch.no_grad():
            outputs = model(image_input.to(DEVICE))
        end_time = time.time()
    
    
        # Load all detection to CPU for further operations.
        outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
        # Carry further only if there are detected boxes.
        if len(outputs[0]['boxes']) != 0:
            boxes = outputs[0]['boxes'].data.numpy()
            scores = outputs[0]['scores'].data.numpy()
            labels_out = outputs[0]['labels'].data.numpy()
            # Filter out detections below the `--threshold` score.
            keep = scores >= args['threshold']
            boxes = boxes[keep].astype(np.int32)
            draw_boxes = boxes.copy()
            # Get the predicted class names for the kept detections.
            pred_classes = [CLASSES[label] for label in labels_out[keep]]
            # Draw the bounding boxes and write the class name on top of it.
            for j, box in enumerate(draw_boxes):
                class_name = pred_classes[j]
                MP[i].append(class_name)
                color = COLORS[CLASSES.index(class_name)]
                # Rescale boxes to the original image size.
                xmin = int((box[0] / image.shape[1]) * orig_image.shape[1])
                ymin = int((box[1] / image.shape[0]) * orig_image.shape[0])
                xmax = int((box[2] / image.shape[1]) * orig_image.shape[1])
                ymax = int((box[3] / image.shape[0]) * orig_image.shape[0])
                cv2.rectangle(orig_image,
                            (xmin, ymin),
                            (xmax, ymax),
                            color[::-1], 
                            3)
                cv2.putText(orig_image, 
                            class_name, 
                            (xmin, ymin-5),
                            cv2.FONT_HERSHEY_SIMPLEX, 
                            0.8, 
                            color[::-1], 
                            2, 
                            lineType=cv2.LINE_AA)    
            
        # Keep one prediction list per image so it stays aligned with GT,
        # even when an image has no detections above the threshold.
        LMP.append(MP[i])
        
    
    
    # Map the numeric class indices stored in GT to class names.
    # These indices must match the order of CLASSES in config.py.
    mapping = {
        1: 'bluetooth',
        2: 'wifi',
        3: 'drone'
    }
    
    # Ground truth and predictions as per-image lists of class names.
    ground_truth = [[mapping[num] for num in sublist] for sublist in GT]
    predictions = LMP
    
    print("Ground truth list:", ground_truth)
    print("All predicted classes lists:", predictions)
    
    # Initialize one binary (2x2) confusion matrix per class.
    confusion_matrices = {
        'bluetooth': np.zeros((2, 2), dtype=int),
        'wifi': np.zeros((2, 2), dtype=int),
        'drone': np.zeros((2, 2), dtype=int)
    }
    
    # Count how often each class occurs in the ground truth and in the predictions.
    gt_counts = Counter()
    pred_counts = Counter()
    
    for gt_list in ground_truth:
        gt_counts.update(gt_list)
    
    for pred_list in predictions:
        pred_counts.update(pred_list)
    
    # Populate confusion matrices
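    # Worked example of the counting logic (hypothetical numbers): if 'wifi'
    # appears 10 times in the ground truth and 12 times in the predictions,
    # then TP = min(10, 12) = 10, FP = max(0, 12 - 10) = 2 and
    # FN = max(0, 10 - 12) = 0. Only the per-class counts are compared;
    # the boxes themselves are never matched by location (no IoU check).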
    for class_idx, class_n in mapping.items():
        TP = min(gt_counts[class_n], pred_counts[class_n])
        FP = max(0, pred_counts[class_n] - gt_counts[class_n])
        FN = max(0, gt_counts[class_n] - pred_counts[class_n])
        
        # TN is tricky to calculate correctly in a multi-class setup but simplified here
        TN = len(ground_truth) + len(predictions) - (TP + FP + FN)
        
        confusion_matrices[class_n][0, 0] = TP
        confusion_matrices[class_n][0, 1] = FP
        confusion_matrices[class_n][1, 0] = FN
        confusion_matrices[class_n][1, 1] = TN
    
    # Print the confusion matrices
    for class_n, matrix in confusion_matrices.items():
        print(f"Confusion Matrix for {class_n}:")
        print(matrix)
        print("TP: ", matrix[0, 0])
        print("TN: ", matrix[1, 1])
        print("FP: ", matrix[0, 1])
        print("FN: ", matrix[1, 0])
        print()
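
    Since matplotlib is already imported but never used, here is a minimal, untested sketch of how the per-class 2x2 matrices built above could be plotted. It only relies on the `confusion_matrices` dictionary and the existing `plt` import; everything else is just an illustration:

    # Sketch: show each 2x2 matrix (rows = prediction, columns = ground truth),
    # following the TP/FP over FN/TN layout used above.
    fig, axes = plt.subplots(1, len(confusion_matrices), figsize=(12, 4))
    for ax, (class_n, matrix) in zip(axes, confusion_matrices.items()):
        ax.imshow(matrix, cmap='Blues')
        ax.set_title(class_n)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['positive', 'negative'])
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['positive', 'negative'])
        ax.set_xlabel('Ground truth')
        ax.set_ylabel('Prediction')
        # Write the raw count into each cell.
        for r in range(2):
            for c in range(2):
                ax.text(c, r, str(matrix[r, c]), ha='center', va='center')
    plt.tight_layout()
    plt.show()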