I defined a GAN model and I want to evaluate it using the FID score. My images are single-channel (the MNIST dataset), but this metric expects 3-channel images. How can I solve this problem?
Try converting the images to 3 channels before evaluating, by repeating the single grayscale channel three times.
import torch
import torchvision
from torcheval import metrics
# Load the MNIST dataset
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
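# Scale the uint8 pixels to floats in [0, 1] and add a channel dimension: (N, 1, 28, 28)
# Only a small subset is used here to keep memory low; use more images for a reliable score
real_images = mnist_dataset.data[:256].float().div(255).unsqueeze(1)
# Convert the 1 channel images to 3 channel images by repeating the gray channel
real_images = real_images.repeat(1, 3, 1, 1)
# Placeholder for the generated images; replace with samples from your GAN in the same format
fake_images = torch.rand_like(real_images)
# Calculate the FID score; the metric compares real images against generated ones
fid = metrics.FrechetInceptionDistance()
fid.update(real_images, is_real=True)
fid.update(fake_images, is_real=False)
fid_score = fid.compute()
# Print the FID score
print('FID score:', fid_score.item())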
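The generated images need the same treatment before they go into the metric. Here is a minimal sketch of a helper for that, assuming your generator ends in a tanh and so outputs values in [-1, 1] with shape (N, 1, H, W). The name `to_fid_input` is made up for this example, and the random tensor only stands in for real generator output.

```python
import torch


def to_fid_input(images: torch.Tensor) -> torch.Tensor:
    """Map (N, 1, H, W) images in [-1, 1] to (N, 3, H, W) floats in [0, 1]."""
    if images.dim() != 4 or images.size(1) != 1:
        raise ValueError("expected a tensor of shape (N, 1, H, W)")
    # Assumption: the generator output is in [-1, 1] (tanh); rescale to [0, 1]
    scaled = (images.float() + 1.0) / 2.0
    # Clamp to guard against small numerical overshoot
    scaled = scaled.clamp(0.0, 1.0)
    # Repeat the gray channel three times so the tensor looks like RGB
    return scaled.repeat(1, 3, 1, 1)


# Stand-in for a batch of generator samples
fake = torch.tanh(torch.randn(8, 1, 28, 28))
fid_ready = to_fid_input(fake)
print(fid_ready.shape)
```

If your generator ends in a sigmoid instead, the output is already in [0, 1] and you can skip the rescaling step.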