
How to implement a Mean Squared Error (MSE) metric for NNI (Neural Network Intelligence) in PyTorch?


I am somewhat new to PyTorch, having used Keras for some years. Now I want to run a neural architecture search (NAS) based on DARTS: Differentiable Architecture Search (see https://nni.readthedocs.io/en/stable/NAS/DARTS.html), which is implemented in PyTorch.

All available examples use accuracy as the metric, but I need to calculate MSE. This is one of the available examples:

DartsTrainer(model,
             loss=criterion,
             metrics=lambda output, target: accuracy(output, target, topk=(1,)),
             optimizer=optim,
             num_epochs=args.epochs,
             dataset_train=dataset_train,
             dataset_valid=dataset_valid,
             batch_size=args.batch_size,
             log_frequency=args.log_frequency,
             unrolled=args.unrolled,
             callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])

where accuracy is defined in a separate function:

def accuracy(output, target, topk=(1,)):
    # Computes the precision@k for the specified values of k
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res

As far as I can see, calculating metrics is more complicated in PyTorch than in Keras. Can someone help, please?
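For reference, PyTorch does ship the squared-error computation itself (`torch.nn.functional.mse_loss`, or the module form `torch.nn.MSELoss`); the extra work is only wrapping it in the dict-of-floats format that the accuracy example returns. The sketch below is a minimal version under the assumption of one regression value per sample; the name `mse_metric` is hypothetical, not part of NNI.

```python
import torch
import torch.nn.functional as F


def mse_metric(output: torch.Tensor, target: torch.Tensor) -> dict:
    """Batch mean squared error, returned in the same dict format as accuracy().

    Assumption: one regression value per sample, so output and target each
    hold batch_size elements (e.g. shapes (N, 1) and (N,)).
    """
    # Flatten both to (N,) so broadcasting cannot silently build an (N, N) matrix.
    output = output.reshape(-1)
    target = target.reshape(-1).to(output.dtype)
    if output.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(output.shape)} vs {tuple(target.shape)}")
    # detach() + item() give a plain Python float, like accuracy() returns.
    mse = F.mse_loss(output.detach(), target, reduction="mean")
    return {"mse": mse.item()}


out = torch.tensor([[1.0], [2.0], [3.0]])
tgt = torch.tensor([1.0, 2.0, 4.0])
print(mse_metric(out, tgt))  # mse is about 0.3333, i.e. (0 + 0 + 1) / 3
```

Since the example passes `metrics` as a callable taking `(output, target)`, a function like this should presumably be usable as `metrics=mse_metric`; for a regression task the `loss` criterion would typically be `torch.nn.MSELoss()` rather than a classification loss.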

As a trial, I wrote this code:

def accuracy_mse(output, target):
    batch_size = target.size(0)
    
    diff = torch.square(output.t()-target)/batch_size
    diff = diff.sum()

    res = dict()

    res["acc_mse"] = diff
    return res    

It seems to work, but I am not 100% sure about it.
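A quick way to check such code is to print the shape of the difference before summing. The toy example below assumes output and target both have shape (N, 1); in that case `output.t()` has shape (1, N), and broadcasting against the (N, 1) target produces an (N, N) matrix of every output/target pair instead of N element-wise differences. (With a 1-D target of shape (N,) the transposed version happens to give the right answer, so whether the transpose helps or hurts depends on the target's shape.)

```python
import torch

# Toy batch: N = 3 samples, one output each; assumed shape (N, 1) for both tensors.
output = torch.tensor([[1.0], [2.0], [3.0]])
target = torch.tensor([[1.0], [2.0], [4.0]])
batch_size = target.size(0)

# Trial version: (1, 3) minus (3, 1) broadcasts to a (3, 3) matrix.
diff_t = output.t() - target
print(diff_t.shape)  # torch.Size([3, 3])
trial = (torch.square(diff_t) / batch_size).sum()

# Element-wise version: shapes already match, so the difference stays (3, 1).
diff = output - target
elementwise = (torch.square(diff) / batch_size).sum()

print(trial.item(), elementwise.item())  # 7.0 vs. about 0.3333
```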


Solution

  • Finally, I figured out that the transpose (.t()) was causing the problem, so the final code is:

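    import torch

    def accuracy_mse(output, target):
        """Computes the MSE of a batch (output and target must have the same shape)."""
        batch_size = target.size(0)

        # Element-wise squared error, averaged over the batch dimension
        diff = torch.square(output - target) / batch_size
        diff = diff.sum()

        res = dict()
        # .item() returns a plain Python float, as the accuracy() example does
        res["mse"] = diff.item()
        return res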
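To double-check the final version, its formula can be compared against `torch.nn.functional.mse_loss`. The sketch below restates the same formula (returning a Python float) on random data, assuming output and target have identical shapes. With one output per sample the two agree; with D outputs per sample, dividing only by the batch size gives D times the element-wise mean (the sum of the per-output MSEs), which is worth knowing when comparing against Keras's `mse`, which also averages over the last axis.

```python
import torch
import torch.nn.functional as F


def accuracy_mse(output, target):
    """Same formula as the final metric above, returning a Python float."""
    batch_size = target.size(0)
    diff = torch.square(output - target) / batch_size
    return {"mse": diff.sum().item()}


torch.manual_seed(0)

# One regression output per sample: shapes (8, 1) and (8, 1).
out1, tgt1 = torch.randn(8, 1), torch.randn(8, 1)
print(accuracy_mse(out1, tgt1)["mse"], F.mse_loss(out1, tgt1).item())  # equal up to rounding

# Four outputs per sample: shapes (8, 4) and (8, 4).
out4, tgt4 = torch.randn(8, 4), torch.randn(8, 4)
ratio = accuracy_mse(out4, tgt4)["mse"] / F.mse_loss(out4, tgt4).item()
print(ratio)  # about 4.0, i.e. the number of outputs per sample
```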