python, pytorch, state-dict

register_buffer with a dict object in PyTorch


I thought this was a simple question but I couldn't find an answer.

I want a member variable of a PyTorch module to be saved/loaded with the model's state_dict. I can do that in __init__ with the following line.

        self.register_buffer('loss_weight', torch.tensor(loss_weight))
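
(For context, here is a minimal sketch of the kind of module I mean; the module name and the scalar loss_weight are just illustrative.)

        import torch
        import torch.nn as nn

        class MyLoss(nn.Module):
            def __init__(self, loss_weight=0.5):
                super().__init__()
                # a scalar works: the buffer is saved/loaded with state_dict
                self.register_buffer('loss_weight', torch.tensor(loss_weight))

        print(MyLoss().state_dict())  # OrderedDict containing 'loss_weight'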

But what if loss_weight is a dict object? Is it allowed? If so, how can I convert it to a tensor?

When I tried, I got the error "Could not infer dtype of dict."


Solution

  • Per the docs, the name argument must be a string, and the tensor argument must be a PyTorch tensor.

    If you have a dict of buffers, you could consider using a dedicated nn.Module for that purpose. Something like this:

    import torch
    import torch.nn as nn

    class BufferDict(nn.Module):
        """Register every entry of a dict of tensors as a named buffer."""
        def __init__(self, input_dict):
            super().__init__()
            for k, v in input_dict.items():
                self.register_buffer(k, v)
                
    input_dict = {'a' : torch.randn(4), 'b' : torch.randn(5)}
    
    bd = BufferDict(input_dict)
    bd.state_dict()
    > OrderedDict([('a', tensor([ 0.1908,  1.6965, -0.3710,  0.4551])),
                   ('b', tensor([-0.6943, -0.0534,  0.1779,  1.3607, -0.2236]))])
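
    If you nest such a module inside your main model, its buffers show up in the parent's state_dict under the submodule's prefix and round-trip through load_state_dict as usual. A quick sketch (MyModel and the buffer names are made up for illustration):

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # appears in the parent state_dict as 'weights.a' and 'weights.b'
            self.weights = BufferDict({'a': torch.randn(4), 'b': torch.randn(5)})

    m, m2 = MyModel(), MyModel()
    m2.load_state_dict(m.state_dict())   # copies the buffers into m2
    torch.equal(m2.weights.a, m.weights.a)
    > True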