I am trying to export a pretrained Mask R-CNN model to ONNX format. In its basic configuration the model has the following structure (here I added batch_size as a dynamic axis):

I want to customize my model and add batch_size to the output (which means I need to add a new dimension to each of the outputs). I wrote the following code to make this possible:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MaskRCNNModel(torch.nn.Module):
    def __init__(self):
        super(MaskRCNNModel, self).__init__()
        self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights='DEFAULT')
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=7)
        self.model.load_state_dict(torch.load("saved_dict.torch"))

    def forward(self, input):
        outputs = self.model.forward(input)
        boxes = []
        labels = []
        scores = []
        masks = []
        for result in outputs:
            box, label, score, mask = result.values()
            boxes.append(box)
            labels.append(label)
            scores.append(score)
            masks.append(mask)
        return boxes, labels, scores, masks

maskrcnn_model = MaskRCNNModel()
maskrcnn_model.eval()
maskrcnn_model.to(device)
x = torch.rand(1, 3, 512, 512)
x = x.to(device)
maskrcnn_model(x)

torch.onnx.export(maskrcnn_model,
                  x,
                  "base_model_100_epochs.onnx",
                  opset_version=11,
                  input_names=["input"],
                  output_names=["boxes", "labels", "scores", "masks"])
but the code above doesn't change anything about the export. The structure of the output stays the same:

What should I do in the custom forward method to be able to add batch_size to the ONNX model output?
As per my original comment, I would discourage deploying most torchvision models with ONNX. It is all around a great module; it just was not originally written with static graphs in mind.

If throughput is a consideration, this implementation of Mask R-CNN is not the way to go. With earlier ONNX opsets, I've had this model spend most of its execution time on h2d/d2h (host-to-device / device-to-host) transfers when operators fell back to CPU. I recommend checking out YOLOv8 by Ultralytics for a newer take on instance segmentation, or some of the many static implementations found on GitHub.
The model is designed with user-friendliness in mind, so for each image in the input batch it outputs a dictionary of tensors with the accepted and post-processed results. For example, if you have two images, with ten detected objects in the first image and three in the second, the output would be
batch = torch.randn((2, 3, 256, 256))  # Input batch of two images
output = mask_rcnn(batch)              # Run the model
results1, results2 = output            # One dictionary per image

for key, value in results1.items():
    print(key, list(value.shape))

>>> boxes [10, 4]
>>> labels [10]
>>> scores [10]
>>> masks [10, 1, 256, 256]

for key, value in results2.items():
    print(key, list(value.shape))

>>> boxes [3, 4]
>>> labels [3]
>>> scores [3]
>>> masks [3, 1, 256, 256]
The reason your custom forward pass changes nothing is that ONNX does not understand Python types. During torch.onnx.export, lists, dictionaries, tuples, etc. have no special meaning: their entries are saved either as tensors or as constants. So the only thing your custom forward pass does is change the order of the outputs, e.g. with the previous example the outputs transform from
>>> boxes1 [10, 4]
>>> labels1 [10]
>>> scores1 [10]
>>> masks1 [10, 1, 256, 256]
>>> boxes2 [3, 4]
>>> labels2 [3]
>>> scores2 [3]
>>> masks2 [3, 1, 256, 256]
to
>>> boxes1 [10, 4]
>>> boxes2 [3, 4]
>>> labels1 [10]
>>> labels2 [3]
>>> scores1 [10]
>>> scores2 [3]
>>> masks1 [10, 1, 256, 256]
>>> masks2 [3, 1, 256, 256]
The Torch ONNX documentation is worth reading to understand how Python and torch types are interpreted during export.
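If you want to see this flattening for yourself, you can inspect the outputs of the exported graph with the onnx package. A minimal sketch, reusing the file name from the question:

import onnx

# Load the exported model and list its graph outputs. With the custom forward
# pass above you will see one plain tensor output per list entry rather than
# lists or dictionaries, since those Python containers are flattened at export.
onnx_model = onnx.load("base_model_100_epochs.onnx")
for output in onnx_model.graph.output:
    dims = [d.dim_param or d.dim_value for d in output.type.tensor_type.shape.dim]
    print(output.name, dims)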
What you want is to have the model output batched results, i.e. you want the model to output tensors
boxes [batch_size, num_detections, 4]
labels [batch_size, num_detections]
scores [batch_size, num_detections]
masks [batch_size, num_detections, 1, 256, 256]
We immediately see that this is impossible without applying some trick. As different images in the batch will have a varying number of predicted objects, we cannot stack 10 bounding boxes from the first image and 3 from the second into a single tensor.
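As a quick illustration (the tensor names here are just stand-ins for the example above), stacking differently sized tensors raises an error, while zero-padding both images to a common max_detections does not:

import torch

boxes_img1 = torch.zeros(10, 4)  # ten detections in the first image
boxes_img2 = torch.zeros(3, 4)   # three detections in the second

# torch.stack requires identical shapes, so this would raise a RuntimeError:
# stacked = torch.stack([boxes_img1, boxes_img2])

# Zero-padding to a common max_detections works instead:
max_detections = 10
padded = torch.zeros(2, max_detections, 4)
padded[0, :boxes_img1.size(0)] = boxes_img1
padded[1, :boxes_img2.size(0)] = boxes_img2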
To output batched results in this scenario, you can define constant-shaped output tensors and paste the results for each image into them. For instance
def forward(self, input):
    # Maximum number of detections the vision model will output per image
    max_detections = self.model.roi_heads.detections_per_img
    # Variables for output tensor shapes
    # Use tensor.size instead of tensor.shape for dynamic inputs
    batch_size, _, input_height, input_width = input.shape

    # Create batched output tensors on the same device as the input
    all_boxes = torch.zeros((batch_size, max_detections, 4), device=input.device)
    all_labels = torch.zeros((batch_size, max_detections), device=input.device)
    all_scores = torch.zeros((batch_size, max_detections), device=input.device)
    # Masks are returned with a redundant channel in the second dimension
    all_masks = torch.zeros((batch_size, max_detections, 1, input_height, input_width), device=input.device)
    # Number of detections per image, stored as integers so they can be used for slicing
    detections_per_batch = torch.zeros((batch_size, 1), dtype=torch.int64, device=input.device)

    # Run forward pass
    outputs = self.model.forward(input)
    for idx, result in enumerate(outputs):
        boxes, labels, scores, masks = result.values()
        # Number of detections for this image
        n_dets = boxes.size(0)
        detections_per_batch[idx] = n_dets

        # Paste this image's results into the output tensors
        all_boxes[idx, :n_dets] = boxes
        all_labels[idx, :n_dets] = labels
        all_scores[idx, :n_dets] = scores
        all_masks[idx, :n_dets] = masks

    return detections_per_batch, all_boxes, all_labels, all_scores, all_masks
This forward pass creates output tensors which can potentially hold all object detections, and copies the realized detections for each image into them. To keep track of which entries are zero-padding and which are actual detections, a tensor detections_per_batch is returned on top of the Mask R-CNN outputs. This is then used to extract the real predictions from the ONNX outputs
for preds, boxes, labels, scores, masks in zip(*outputs):
    n_dets = int(preds)  # number of real detections for this image
    detected_boxes = boxes[:n_dets]
    detected_labels = labels[:n_dets]
    ...
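For completeness, here is a minimal sketch of producing those outputs with ONNX Runtime; the onnxruntime session, the file name and the output names are my own assumptions, not part of the original setup:

import numpy as np
import onnxruntime as ort

# Assumes the wrapper with the batched forward pass was exported as in the
# question, but with output_names=["detections_per_batch", "boxes", "labels",
# "scores", "masks"] to match the new return values.
session = ort.InferenceSession("batched_model.onnx")  # hypothetical file name
batch = np.random.rand(2, 3, 512, 512).astype(np.float32)

# session.run returns a list of arrays in output_names order, which is exactly
# the `outputs` unpacked by the zip(*outputs) loop above.
outputs = session.run(None, {"input": batch})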
This will cause problems for I/O- or memory-bound applications, as the model always returns outputs with space for every potential detected mask. If you have a good upper bound for the number of objects per image, you can limit this by reducing model.roi_heads.detections_per_img.
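For example, a minimal sketch of capping the detections before export, using the wrapper class from the question (the value 20 is an arbitrary example; the torchvision default is 100):

maskrcnn_model = MaskRCNNModel()
# Lower the per-image detection cap so the padded output tensors stay small.
maskrcnn_model.model.roi_heads.detections_per_img = 20
maskrcnn_model.eval()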