Tags: python-3.x, deep-learning, pytorch, resnet, quantization-aware-training

"NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend" while implementing QAT on resnet18 using pytorch


I am trying to implement Quantization-Aware Training (QAT) for a ResNet-18 model. During inference I get this error:

NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend

I am trying to follow this PyTorch documentation on using its QAT API.

Here is my code. I am also attaching a Google Colab notebook link.

Block 1 - Import the necessary libraries and define the training and evaluation functions

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import copy
import numpy as np
import os

def evaluate_model(model, test_loader, device, criterion=None):
    """Evaluate ``model`` over every batch of ``test_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode and moved to
            ``device`` before use.
        test_loader: iterable of ``(inputs, labels)`` batches that also exposes
            ``.dataset`` (its length is used as the sample count).
        device: device the model, inputs and labels are moved to.
        criterion: optional loss ``criterion(outputs, labels)``; when ``None``
            the reported loss is 0.

    Returns:
        ``(eval_loss, eval_accuracy)`` as Python floats: the loss averaged over
        samples, and the fraction of samples whose arg-max prediction along
        dim 1 equals the label.
    """
    model.eval()
    model.to(device)
    running_loss = 0.0
    running_corrects = 0
    # Evaluation never calls backward(); disabling autograd avoids building
    # (and holding memory for) a graph on every batch.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0.0
            # criterion is assumed to return a per-batch mean, so re-weight by
            # batch size before averaging over the whole dataset.
            running_loss += loss * inputs.size(0)
            # Accumulate as a Python int rather than a device tensor.
            running_corrects += (preds == labels).sum().item()
    num_samples = len(test_loader.dataset)
    eval_loss = running_loss / num_samples
    eval_accuracy = running_corrects / num_samples
    return eval_loss, eval_accuracy



def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):
    """Train ``model`` with Adam, printing train/eval metrics after every epoch.

    The model is evaluated once before training (printed as epoch -1) and
    again after each epoch via ``evaluate_model``.

    Args:
        model: network to train; moved to ``device``.
        train_loader: iterable of ``(inputs, labels)`` batches that also
            exposes ``.dataset`` (its length is used as the sample count).
        test_loader: loader handed to ``evaluate_model`` each epoch.
        device: device the model and batches are moved to.
        learning_rate: initial Adam learning rate. NOTE(review): the default
            of 1e-1 looks like an SGD-style value and is very large for Adam;
            pass an explicit value (the call in this notebook uses 1e-3).
        num_epochs: number of passes over ``train_loader``.

    Returns:
        The model that was passed in. If the loss becomes NaN, training stops
        early and the model is returned as-is.
    """
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # Use the caller-supplied learning rate (it was previously ignored).
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Decay the learning rate 10x at epochs 100 and 150.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)

    # evaluate_model switches the model to eval mode itself.
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

    num_train = len(train_loader.dataset)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if torch.isnan(loss):
                print("NaN in Loss!")
                return model
            loss.backward()
            optimizer.step()
            # Re-weight the per-batch mean loss by batch size for a dataset average.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        train_loss = running_loss / num_train
        train_accuracy = running_corrects / num_train
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
        scheduler.step()
        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))
    return model

Block 2 - Load the training and test sets (CIFAR-100 resized to 224x224)

# Both splits resize so the shorter side is 224, convert to tensors and map
# pixel values from [0, 1] to [-1, 1]; only training adds random flips.
transform_train = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training data is shuffled each epoch; test data keeps a fixed order.
trainset = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR100(
    root="./data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2
)

print("Data loaded and transformed successfully!")

Block 3 - Wrap the model with quantization stubs, fuse modules, run QAT, and convert to int8

class QuantizedResNet18(nn.Module):
    """Wrap a float model with quant/dequant stubs for eager-mode quantization.

    ``QuantStub`` marks where inputs are converted from floating point to
    quantized, and ``DeQuantStub`` marks where outputs are converted back to
    floating point in the quantized model.

    Args:
        model_fp32: the (optionally fused) floating-point model to wrap.
        verbose: if True, print the shape and dtype of the tensor at each
            stage of ``forward``. Off by default: ``forward`` runs once per
            batch, so unconditional printing floods training/eval logs.
    """

    def __init__(self, model_fp32, verbose=False):
        super().__init__()
        # Converts inputs from floating point to quantized.
        self.quant = torch.ao.quantization.QuantStub()
        # Converts outputs from quantized back to floating point.
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.model_fp32 = model_fp32
        self.verbose = verbose

    def _log(self, label, x):
        """Print ``label`` with the shape and dtype of ``x`` when verbose."""
        if self.verbose:
            print(f"{label}: {x.shape}, dtype: {x.dtype}")

    def forward(self, x):
        self._log("Input shape before quant", x)
        x = self.quant(x)
        self._log("Input shape after quant", x)
        x = self.model_fp32(x)
        # From here on the tensor is the model's output, so label it as such.
        self._log("Output shape before dequant", x)
        x = self.dequant(x)
        self._log("Output shape after dequant", x)
        return x

model = resnet18(num_classes=100, pretrained=False)
fused_model = copy.deepcopy(model)
fused_model.eval()
# QAT needs a qconfig whose weight/activation hooks are FakeQuantize modules.
# The post-training `get_default_qconfig` only attaches observers, so with it
# nothing would actually be fake-quantized while training.
qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
fused_model.qconfig = qconfig

# Fuse Conv+BN(+ReLU) with the QAT variant. Plain `fuse_modules` folds BatchNorm
# into the conv weights up front, so an untrained network would be trained with
# no BatchNorm at all; `fuse_modules_qat` keeps BN and folds it only on convert.
# NOTE: torchvision's BasicBlock reuses `self.relu` after the residual add, so
# fusing ["conv1", "bn1", "relu"] also turns that post-add ReLU into Identity.
fused_model = torch.ao.quantization.fuse_modules_qat(fused_model, [["conv1", "bn1", "relu"]], inplace=True)
for module_name, module in fused_model.named_children():
    if "layer" in module_name:
        for basic_block in module.children():
            torch.ao.quantization.fuse_modules_qat(basic_block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
            if basic_block.downsample is not None:
                torch.ao.quantization.fuse_modules_qat(basic_block.downsample, [["0", "1"]], inplace=True)

quantized_model_1 = QuantizedResNet18(model_fp32=fused_model)
quantized_model_1.qconfig = qconfig
# Fall back to CPU so the script still runs on machines without CUDA.
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
quantized_model_1_prepared = torch.ao.quantization.prepare_qat(quantized_model_1.train())
trained_quantized_model_1_prepared = train_model(model=quantized_model_1_prepared, train_loader=trainloader, test_loader=testloader, device=cuda_device, learning_rate=1e-3, num_epochs=1)
# The fbgemm-quantized model runs on the QuantizedCPU backend, so convert and
# evaluate it on CPU.
cpu_device = torch.device("cpu")
trained_quantized_model_1_prepared.to(cpu_device)
trained_quantized_model_1_prepared.eval()
trained_quantized_model_1_prepared_int8 = torch.ao.quantization.convert(trained_quantized_model_1_prepared)
# NOTE: with torchvision's stock resnet18, the residual `out += identity` has no
# QuantizedCPU kernel, so this call raises NotImplementedError for
# 'aten::add.out' unless the blocks use a FloatFunctional for the addition.
print(evaluate_model(model=trained_quantized_model_1_prepared_int8, test_loader=testloader, device=cpu_device))

The error occurs in the last line, when I call the evaluate_model function, specifically during the forward pass (outputs = model(inputs)).

I get the following error

NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::add.out' is only available for these backends: [CPU, CUDA, Meta, MkldnnCPU, SparseCPU, SparseCUDA, SparseMeta, SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, 

Solution

  • This tutorial explains that, as of PyTorch 2.0, this feature is in beta and the original model needs at least one change (https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture) for the residual addition:

    Replacing addition with nn.quantized.FloatFunctional

    You can see in your error trace that this line of code throws the error:

            out += identity
    
    

    https://github.com/pytorch/vision/blob/229d8523bfa9a2696872d76b1cdb6815028f1e03/torchvision/models/resnet.py#L102

    So we need to:

    1. reimplement BasicBlock, replacing the += operator with skip_add;
    2. pass the new BasicBlock to the ResNet constructor.

    Step 1

    from functools import partial
    from typing import Any, Callable, List, Optional, Type, Union
    
    import torch
    import torch.nn as nn
    from torch import Tensor
    
    from torchvision.transforms._presets import ImageClassification
    from torchvision.utils import _log_api_usage_once
    from torchvision.models.resnet import Bottleneck
    
    
    from torchvision.models.resnet import conv3x3
    
    class BasicBlock(nn.Module):
        """ResNet basic residual block with a quantization-friendly skip connection.

        Same layers as torchvision's ``BasicBlock`` (two 3x3 convs, each followed
        by batch norm), but the residual addition and the final ReLU go through a
        ``FloatFunctional`` (``skip_add.add_relu``), so after conversion the add
        has a quantized implementation instead of failing on ``aten::add.out``.

        Using ``add_relu`` instead of calling ``self.relu`` after the add also
        matters for fusion: fusing ``["conv1", "bn1", "relu"]`` replaces
        ``self.relu`` with ``nn.Identity``, which would otherwise silently drop
        the post-add ReLU.
        """

        expansion: int = 1

        def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
        ) -> None:
            super().__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            if groups != 1 or base_width != 64:
                raise ValueError("BasicBlock only supports groups=1 and base_width=64")
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            # Used only after bn1 (the conv1/bn1/relu fusion target).
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride
            # Residual add + ReLU; replaced by its quantized counterpart on convert.
            self.skip_add = nn.quantized.FloatFunctional()

        def forward(self, x: Tensor) -> Tensor:
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            # Equivalent to `out += identity; out = relu(out)`, but quantizable
            # and unaffected by `self.relu` being fused away.
            return self.skip_add.add_relu(out, identity)
    

    Step 2

    Inject it by creating a new constructor function for the quantized model.

    from torchvision.models.resnet import ResNet18_Weights, _resnet
    from torchvision.models._api import register_model, Weights, WeightsEnum
    from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
    
    @register_model()
    @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
    def quantizedresnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
        """Build a ResNet-18 from the FloatFunctional-based ``BasicBlock``.

        Construction is delegated to torchvision's ``_resnet`` helper with this
        notebook's ``BasicBlock``. The legacy ``pretrained`` flag is translated
        by ``handle_legacy_interface`` (True maps to
        ``ResNet18_Weights.IMAGENET1K_V1``). Extra keyword arguments, such as
        ``num_classes``, are forwarded to ``_resnet``.

        NOTE(review): ``register_model`` adds this builder to torchvision's
        model registry under its function name. Re-running this cell in the
        same session may raise a duplicate-name error; confirm this for your
        torchvision version.
        """
        checked_weights = ResNet18_Weights.verify(weights)
        blocks_per_stage = [2, 2, 2, 2]
        return _resnet(BasicBlock, blocks_per_stage, checked_weights, progress, **kwargs)
    

    Step 3

    Go back to the cell with

    # Cell as it appears in the question: builds torchvision's stock resnet18.
    # NOTE(review): for QAT, PyTorch also provides get_default_qat_qconfig;
    # get_default_qconfig is the post-training variant.
    model = resnet18(num_classes=100, pretrained=False)
    fused_model = copy.deepcopy(model)
    fused_model.eval()
    qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    fused_model.qconfig = qconfig
    

    and replace it with the following (note that the first line is changed):

    # Build ResNet-18 from the FloatFunctional-based BasicBlock (Step 1) so the
    # residual addition has a quantized implementation after convert.
    model = quantizedresnet18(num_classes=100, pretrained=False)
    fused_model = copy.deepcopy(model)
    fused_model.eval()
    # QAT needs a qconfig with FakeQuantize modules; get_default_qconfig only
    # attaches observers, so nothing would be fake-quantized during training.
    qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    fused_model.qconfig = qconfig
    

    Then execute all cells below again.