I thought this was a simple question, but I couldn't find an answer.
I want a member variable of a PyTorch module to be saved/loaded with the model's state_dict. For a tensor, I can do that in __init__ with the following line.
self.register_buffer('loss_weight', torch.tensor(loss_weight))
But what if loss_weight is a dict object? Is that allowed? If so, how can I convert it to a tensor?
When I tried, I got the error "Could not infer dtype of dict."
Per the docs, the name argument must be a string and the tensor argument must be a PyTorch tensor, so you can't register a dict directly.
If you have a dict of buffers, you could consider using a dedicated nn.Module
for that purpose. Something like this:
import torch
import torch.nn as nn

class BufferDict(nn.Module):
    def __init__(self, input_dict):
        super().__init__()
        # register each entry so it is included in state_dict()
        for k, v in input_dict.items():
            self.register_buffer(k, v)
input_dict = {'a': torch.randn(4), 'b': torch.randn(5)}
bd = BufferDict(input_dict)
bd.state_dict()
> OrderedDict([('a', tensor([ 0.1908, 1.6965, -0.3710, 0.4551])),
('b', tensor([-0.6943, -0.0534, 0.1779, 1.3607, -0.2236]))])
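Since the buffers appear in state_dict() under their dict keys, they also round-trip through save/load like any other module state. A minimal sketch (the keys of the fresh instance must match those of the saved one):

```python
import torch
import torch.nn as nn

class BufferDict(nn.Module):
    """Wraps a dict of tensors so each entry is a registered buffer."""
    def __init__(self, input_dict):
        super().__init__()
        for k, v in input_dict.items():
            self.register_buffer(k, v)

# Save the buffers from one instance...
src = BufferDict({'a': torch.randn(4), 'b': torch.randn(5)})
state = src.state_dict()

# ...and load them into a fresh instance built with the same keys.
dst = BufferDict({'a': torch.zeros(4), 'b': torch.zeros(5)})
dst.load_state_dict(state)

assert torch.equal(dst.a, src.a)
assert torch.equal(dst.b, src.b)
```

Note that each registered buffer is also accessible as an attribute (bd.a, bd.b), and buffers move with the module on .to(device), unlike plain dict members.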