Tags: keras, sdk, pytorch, metrics, dcgan

DCGAN discriminator accuracy metric using PyTorch


I am implementing DCGANs using PyTorch.

It works well in that I can generate images of reasonable quality; however, I now want to evaluate the health of the GAN models using metrics, mainly the ones introduced by this guide: https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/

Their implementation uses Keras, whose SDK lets you define which metrics you want when you compile the model (see https://keras.io/api/models/model/). In this case, the metric is the accuracy of the discriminator, i.e. the percentage of images it correctly identifies as real or generated.

With the PyTorch SDK, I can't seem to find a similar feature that would help me easily acquire this metric from my model.

Does PyTorch provide functionality to define and extract common metrics from a model?


Solution

  • Pure PyTorch does not provide metrics out of the box, but it is easy to define them yourself.

    Also, there is no such thing as "extracting metrics from a model". Metrics measure something (in this case, the accuracy of the discriminator); they are not inherent to the model.

    Binary accuracy

    In your case, you are looking for a binary accuracy metric. The code below works with either logits (unnormalized scores output by the discriminator, typically a final nn.Linear layer without activation) or probabilities (a final nn.Linear followed by a sigmoid activation):

    import typing
    import torch
    
    
    class BinaryAccuracy:
        def __init__(
            self,
            logits: bool = True,
            reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
        ):
            # Threshold at 0 for raw logits (sigmoid(0) == 0.5),
            # at 0.5 for probabilities.
            self.threshold = 0.0 if logits else 0.5
            self.reduction = reduction
    
        def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
            # Flatten both tensors so a (N, 1)-shaped prediction compares
            # against a (N,)-shaped target elementwise instead of
            # broadcasting to (N, N).
            return self.reduction(
                ((y_pred.flatten() > self.threshold) == y_true.bool().flatten()).float()
            )
    

    Usage:

    metric = BinaryAccuracy()
    target = torch.randint(2, size=(64,))
    outputs = torch.randn(size=(64, 1))
    
    print(metric(outputs, target))
    
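    For GAN health monitoring specifically, the linked guide tracks the discriminator's accuracy on real and fake batches separately (both collapsing toward 0% or sitting at 100% signals trouble). A minimal sketch of that, using a hypothetical stand-in discriminator in place of your actual DCGAN one:

    ```python
    import torch
    from torch import nn

    # Hypothetical stand-in for your DCGAN discriminator: one logit per
    # image, no final sigmoid. Substitute your real model here.
    discriminator = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * 32 * 32, 1),
    )

    def binary_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Threshold raw logits at 0, equivalent to probability 0.5 after sigmoid.
        return ((logits.flatten() > 0) == labels.bool()).float().mean()

    real_images = torch.randn(64, 3, 32, 32)  # stand-in for a real batch
    fake_images = torch.randn(64, 3, 32, 32)  # stand-in for generator output

    with torch.no_grad():
        real_acc = binary_accuracy(discriminator(real_images), torch.ones(64))
        fake_acc = binary_accuracy(discriminator(fake_images), torch.zeros(64))

    print(f"D accuracy on real: {real_acc:.2f}, on fake: {fake_acc:.2f}")
    ```

    Logging these two numbers once per epoch (or every few batches) gives you the same signal the Keras guide plots.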

    PyTorch Lightning or other third-party libraries

    You can also use PyTorch Lightning or another framework built on top of PyTorch that defines common metrics such as accuracy, for example the torchmetrics package, which originated in PyTorch Lightning.
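    As a sketch, assuming the torchmetrics package is installed, its built-in binary accuracy metric covers this case; per its documentation, it applies a sigmoid automatically when it receives float predictions outside the [0, 1] range (i.e. raw logits):

    ```python
    import torch
    from torchmetrics.classification import BinaryAccuracy

    metric = BinaryAccuracy()  # thresholds probabilities at 0.5 by default

    target = torch.randint(2, size=(64,))
    outputs = torch.randn(size=(64,))  # raw logits; sigmoid applied automatically

    print(metric(outputs, target))
    ```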