I am trying to implement Quantization Aware Training (QAT) for a ResNet18 model. During inference I get this error:
NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend
I am trying to follow this documentation by pytorch for using their QAT API
Here is my code; I am also attaching a Google Colab notebook link.
Block 1 - Importing the necessary libraries, defining training and evaluation functions
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import copy
import numpy as np
import os
def evaluate_model(model, test_loader, device, criterion=None):
    """Evaluate ``model`` on every batch yielded by ``test_loader``.

    The model is put in eval mode and moved to ``device`` first. The forward
    passes run under ``torch.no_grad()`` so that evaluation does not build an
    autograd graph for each batch.

    Args:
        model: Module called as ``model(inputs)``; predictions are the argmax
            over dimension 1 of its output.
        test_loader: Iterable of ``(inputs, labels)`` batches that exposes a
            ``dataset`` attribute whose ``len`` is the number of samples.
        device: Device the model and every batch are moved to.
        criterion: Optional callable ``criterion(outputs, labels)`` returning
            a scalar loss tensor. When ``None`` the reported loss is 0.

    Returns:
        Tuple ``(eval_loss, eval_accuracy)``: the batch-size-weighted loss sum
        and the number of correct predictions, each divided by
        ``len(test_loader.dataset)``.
    """
    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    # Evaluation never calls backward(), so gradient tracking is pure overhead.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0

            # Weight the batch loss by batch size so a smaller final batch
            # does not skew the per-sample average.
            running_loss += loss * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

    num_samples = len(test_loader.dataset)
    eval_loss = running_loss / num_samples
    eval_accuracy = running_corrects / num_samples
    return eval_loss, eval_accuracy
def train_model(model, train_loader, test_loader, device, learning_rate=1e-4, num_epochs=200):
    """Train ``model`` with Adam and print train/eval metrics after each epoch.

    Args:
        model: Module to train; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches that
            exposes a ``dataset`` attribute with a length.
        test_loader: Loader passed to ``evaluate_model`` for the evaluation
            before training and after every epoch.
        device: Device used for the model and all batches.
        learning_rate: Adam learning rate. Previously this argument was
            accepted but ignored (the optimizer was hard-coded to ``1e-4``);
            it is now honoured, and the default is ``1e-4`` so callers that
            omit it train exactly as before.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object. Training stops early, returning the model
        as-is, the first time the loss is NaN.
    """
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # LR is multiplied by 0.1 at epochs 100 and 150.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1
    )

    # Baseline metrics before any weight update, reported as epoch -1.
    # evaluate_model switches the model to eval mode itself.
    eval_loss, eval_accuracy = evaluate_model(
        model=model, test_loader=test_loader, device=device, criterion=criterion
    )
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

    for epoch in range(num_epochs):
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # A NaN loss would poison every parameter on the next step, so
            # bail out and hand back the model in its current state.
            if torch.isnan(loss):
                print("NaN in Loss!")
                return model

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        num_train_samples = len(train_loader.dataset)
        train_loss = running_loss / num_train_samples
        train_accuracy = running_corrects / num_train_samples

        eval_loss, eval_accuracy = evaluate_model(
            model=model, test_loader=test_loader, device=device, criterion=criterion
        )

        scheduler.step()

        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

    return model
Block 2 - Loading trainset and testset (CIFAR 100 resized to 224*224)
# Shared preprocessing settings: images are resized to 224 and each channel is
# normalised with mean 0.5 and std 0.5.
_IMAGE_SIZE = 224
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Training pipeline: same as the test pipeline plus a random horizontal flip.
_train_steps = [
    transforms.Resize(_IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_train = transforms.Compose(_train_steps)

# Test pipeline: deterministic (no augmentation).
_test_steps = [
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_test = transforms.Compose(_test_steps)

# CIFAR-100 training split, downloaded into ./data when missing.
trainset = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# CIFAR-100 test split; order is kept fixed (shuffle=False).
testset = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

print("Data loaded and transformed successfully!")
Block 3 -
class QuantizedResNet18(nn.Module):
    """Wrap a float model between a ``QuantStub`` and a ``DeQuantStub``.

    The stubs mark where tensors enter and leave the quantized region: the
    input goes through ``quant``, then the wrapped model, then ``dequant``.
    Every forward call also prints the shape and dtype of the tensor at each
    of those four points.
    """

    def __init__(self, model_fp32):
        """Store ``model_fp32`` and create the quant/dequant stubs."""
        super().__init__()
        # Entry point of the quantized region (float -> quantized); only the
        # model input passes through it.
        self.quant = torch.ao.quantization.QuantStub()
        # Exit point of the quantized region (quantized -> float); only the
        # model output passes through it.
        self.dequant = torch.ao.quantization.DeQuantStub()
        # The wrapped floating-point network.
        self.model_fp32 = model_fp32

    @staticmethod
    def _report(label, tensor):
        """Print ``label`` followed by the tensor's shape and dtype."""
        print(f"{label}: {tensor.shape}, dtype: {tensor.dtype}")

    def forward(self, x):
        """Run ``x`` through quant -> wrapped model -> dequant, logging each stage."""
        self._report("Input shape before quant", x)

        quantized_in = self.quant(x)
        self._report("Input shape after quant", quantized_in)

        model_out = self.model_fp32(quantized_in)
        self._report("Input shape", model_out)

        float_out = self.dequant(model_out)
        self._report("Input shape", float_out)
        return float_out
# Float ResNet-18 with a 100-class head; all fusion is done on a deep copy.
model = resnet18(num_classes=100, pretrained=False)
fused_model = copy.deepcopy(model)
# This script fuses in eval mode, before the model is put in train mode for prepare_qat.
fused_model.eval()

# prepare_qat is meant to be paired with a QAT qconfig, which inserts
# fake-quantize modules. The post-training qconfig from get_default_qconfig
# only attaches observers, so training would not simulate quantization.
qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
fused_model.qconfig = qconfig

# Fuse conv/bn/relu groups in place, walking the ResNet structure by hand:
# first the stem, then every basic block inside layer1..layer4.
fused_model = torch.ao.quantization.fuse_modules(fused_model, [["conv1", "bn1", "relu"]], inplace=True)
for module_name, module in fused_model.named_children():
    if "layer" not in module_name:
        continue
    for _, basic_block in module.named_children():
        torch.ao.quantization.fuse_modules(
            basic_block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True
        )
        # The optional shortcut is a (conv, bn) pair registered as "0", "1".
        for sub_block_name, sub_block in basic_block.named_children():
            if sub_block_name == "downsample":
                torch.ao.quantization.fuse_modules(sub_block, [["0", "1"]], inplace=True)

# Wrap the fused network with the quant/dequant stubs.
quantized_model_1 = QuantizedResNet18(model_fp32=fused_model)
quantized_model_1.qconfig = qconfig

# prepare_qat is given the model in training mode.
cuda_device = torch.device("cuda:0")
quantized_model_1_prepared = torch.ao.quantization.prepare_qat(quantized_model_1.train())

trained_quantized_model_1_prepared = train_model(
    model=quantized_model_1_prepared,
    train_loader=trainloader,
    test_loader=testloader,
    device=cuda_device,
    learning_rate=1e-3,
    num_epochs=1,
)

# The fbgemm int8 kernels run on CPU: move the model there and switch to
# eval mode before converting.
cpu_device = torch.device("cpu:0")
trained_quantized_model_1_prepared.to(cpu_device)
trained_quantized_model_1_prepared.eval()
trained_quantized_model_1_prepared_int8 = torch.ao.quantization.convert(trained_quantized_model_1_prepared)

print(evaluate_model(model=trained_quantized_model_1_prepared_int8, test_loader=testloader, device=cpu_device))
The issue is in the last line, when I run the evaluate_model function, specifically during inference (outputs = model(inputs)).
I get the following error
NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::add.out' is only available for these backends: [CPU, CUDA, Meta, MkldnnCPU, SparseCPU, SparseCUDA, SparseMeta, SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot,
The static quantization tutorial notes that in torch 2.0 this feature is still in beta and that the original model needs at least one change (https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture) to handle the residual addition:
Replacing addition with nn.quantized.FloatFunctional
You can see in your error trace that this line of code throws the error:
out += identity
So we need to:
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union
import torch
import torch.nn as nn
from torch import Tensor
from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models.resnet import Bottleneck
from torchvision.models.resnet import conv3x3
class BasicBlock(nn.Module):
    """ResNet basic residual block whose skip connection can be quantized.

    The residual addition goes through an ``nn.quantized.FloatFunctional``
    (``skip_add``) instead of ``out += identity``, because the plain tensor
    add has no quantized-CPU kernel. The addition and the final ReLU are
    performed together with ``skip_add.add_relu`` so the block does not rely
    on ``self.relu`` after the skip connection: fusing
    ``["conv1", "bn1", "relu"]`` with ``fuse_modules`` replaces ``self.relu``
    with ``nn.Identity``, which would otherwise silently drop that last ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut tensor; when ``None`` the input itself is used.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Callable building the normalisation layer from a channel
            count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1.
        # The convolutions are 3x3, padding 1, no bias: the same layer
        # torchvision's conv3x3 helper builds for groups=1 and dilation=1.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = norm_layer(planes)
        # Kept under the name "relu" so fusion lists such as
        # ["conv1", "bn1", "relu"] keep working.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Functional wrapper for the residual add.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Add + ReLU in one functional call; independent of self.relu, which
        # becomes nn.Identity once it is fused into conv1.
        return self.skip_add.add_relu(out, identity)
Inject this block by creating a new constructor function for the quantized model.
from torchvision.models.resnet import ResNet18_Weights, _resnet
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def quantizedresnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
    """Build a ResNet-18 from the ``BasicBlock`` class defined in this file.

    Args:
        weights: Optional ``ResNet18_Weights`` entry; it is passed through
            ``ResNet18_Weights.verify`` before use.
        progress: Forwarded to ``_resnet``.
        **kwargs: Extra keyword arguments forwarded to ``_resnet``
            (e.g. ``num_classes``).

    Returns:
        Whatever ``_resnet`` returns for the given block type and layout.
    """
    # NOTE(review): register_model() registers this function's name with
    # torchvision; re-running the cell in one session may fail because the
    # name is already taken -- confirm before relying on re-execution.
    checked_weights = ResNet18_Weights.verify(weights)
    # Number of blocks in each of the four residual stages.
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet(BasicBlock, blocks_per_stage, checked_weights, progress, **kwargs)
Go back to the cell with
# Original cell: the model comes from torchvision's stock resnet18 constructor.
model = resnet18(num_classes=100, pretrained=False)
# Work on a deep copy in eval mode so `model` itself is left untouched.
fused_model = copy.deepcopy(model)
fused_model.eval()
# Attach the default 'fbgemm' qconfig to the copy.
qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
fused_model.qconfig = qconfig
And replace it with this (only the first line is changed):
# Build the model with the quantizedresnet18 constructor instead of resnet18.
model = quantizedresnet18(num_classes=100, pretrained=False)
# Work on a deep copy in eval mode so `model` itself is left untouched.
fused_model = copy.deepcopy(model)
fused_model.eval()
# NOTE(review): the PyTorch QAT workflow pairs prepare_qat with
# get_default_qat_qconfig; get_default_qconfig is the post-training
# variant -- confirm which one is intended here.
qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
fused_model.qconfig = qconfig
Then execute all cells below again.