
How to compute the model complexity of the pretrained Faster R-CNN FPN model from torchvision?


I got the pretrained FASTERRCNN_RESNET50_FPN model from PyTorch (torchvision).

Now I want to compute the model's complexity (number of parameters and FLOPs) as reported by torchvision.

How do I do this? For classification models (e.g. resnet50) we can normally use tools such as thop or ptflops. The main concern is: what is the correct input image size (width and height; channels = 3, of course)?

From what I've read, Faster R-CNN accepts variable input image sizes, but I haven't found the step where the image is resized during the forward pass. I assumed the image is passed to the backbone (resnet50) first, so I chose an input size of (224, 224), the same as ImageNet. But when I try this with ptflops, the reported FLOPs are very unstable.
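For reference, torchvision's detection models do the resizing themselves: `model.transform` (a `GeneralizedRCNNTransform`) runs at the start of `forward`, and for `fasterrcnn_resnet50_fpn` its defaults are `min_size=800` and `max_size=1333`, with batches zero-padded to a multiple of 32. The pure-Python function below is a minimal sketch of that resize rule under those assumed defaults, not torchvision's own code; `rcnn_transformed_size` is a hypothetical helper name.

```python
import math
from fractions import Fraction


def rcnn_transformed_size(height, width, min_size=800, max_size=1333, size_divisible=32):
    """Sketch of the (H, W) a Faster R-CNN-style transform feeds to the backbone.

    Assumes torchvision's default eval behaviour for fasterrcnn_resnet50_fpn:
    scale the shorter side to ``min_size`` unless that pushes the longer side
    past ``max_size``, then zero-pad each side up to a multiple of ``size_divisible``.
    Exact rational arithmetic is used here; torchvision works in floating point,
    so edge cases may differ by a pixel.
    """
    short_side, long_side = min(height, width), max(height, width)
    # One scale factor for both sides, so the aspect ratio is preserved.
    scale = min(Fraction(min_size, short_side), Fraction(max_size, long_side))
    resized_h = math.floor(height * scale)
    resized_w = math.floor(width * scale)
    # Batching pads the resized image up to the next multiple of size_divisible.
    padded_h = math.ceil(resized_h / size_divisible) * size_divisible
    padded_w = math.ceil(resized_w / size_divisible) * size_divisible
    return (resized_h, resized_w), (padded_h, padded_w)


for h, w in [(224, 224), (480, 640), (800, 800)]:
    print((h, w), "->", rcnn_transformed_size(h, w))
```

Under these assumptions a (224, 224) image is upsampled to (800, 800) before it reaches the backbone, so passing a 224×224 input to a FLOP counter does not measure a 224×224 forward pass.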

Any recommendations are appreciated. Thanks in advance!

I tried ptflops; what I am looking for is the correct input image size to use.


Solution

  • The number of parameters does not depend on the input size. For the computational complexity (e.g., MACs), torchvision computes the reported numbers with the following default input sizes for its detection models:

    detection_models_input_dims = {
        "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
        "fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
        "fasterrcnn_resnet50_fpn": (800, 800),
        "fasterrcnn_resnet50_fpn_v2": (800, 800),
        "fcos_resnet50_fpn": (800, 800),
        "keypointrcnn_resnet50_fpn": (1333, 1333),
        "maskrcnn_resnet50_fpn": (800, 800),
        "maskrcnn_resnet50_fpn_v2": (800, 800),
        "retinanet_resnet50_fpn": (800, 800),
        "retinanet_resnet50_fpn_v2": (800, 800),
        "ssd300_vgg16": (300, 300),
        "ssdlite320_mobilenet_v3_large": (320, 320),
    }
    

    These sizes are defined in torchvision's test suite: https://github.com/pytorch/vision/blob/25c8a3a2cc2699e4e261b9e0777a6dc5badb5f9f/test/test_extended_models.py#L158

    More discussion: https://github.com/pytorch/vision/pull/6936

    Running ptflops with an input size of (800, 800) yields the same numbers as those reported by torchvision.
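A minimal sketch of that measurement, assuming `torch`, `torchvision` (0.14 or later, for `torchvision.models.get_model`) and `ptflops` are installed. The heavy imports are deferred into `measure_complexity` so the helpers can be loaded on their own; `DETECTION_INPUT_DIMS`, `ptflops_input_res`, `count_parameters` and `measure_complexity` are hypothetical names, and the dictionary is a subset copied from the table above.

```python
# Subset of torchvision's per-model input sizes (height, width) from the test suite.
DETECTION_INPUT_DIMS = {
    "fasterrcnn_resnet50_fpn": (800, 800),
    "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
    "ssd300_vgg16": (300, 300),
}


def ptflops_input_res(model_name):
    """Return the (C, H, W) tuple that ptflops expects for this model."""
    height, width = DETECTION_INPUT_DIMS[model_name]
    return (3, height, width)


def count_parameters(model):
    # Depends only on the architecture, not on the input size.
    return sum(p.numel() for p in model.parameters())


def measure_complexity(model_name="fasterrcnn_resnet50_fpn"):
    # Deferred imports: only needed when actually measuring a model.
    import torchvision
    from ptflops import get_model_complexity_info

    # weights="DEFAULT" downloads the pretrained checkpoint.
    model = torchvision.models.get_model(model_name, weights="DEFAULT")
    # In train mode torchvision detection models require targets; eval runs inference.
    model.eval()
    macs, params = get_model_complexity_info(
        model,
        ptflops_input_res(model_name),
        as_strings=True,
        print_per_layer_stat=False,
    )
    return macs, params, count_parameters(model)
```

Usage: `print(measure_complexity("fasterrcnn_resnet50_fpn"))` returns the ptflops MAC and parameter strings together with a direct parameter count for cross-checking.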