I have a model trained in PyTorch, saved in .pth format. Is it possible to use and load that model in Flux.jl? I looked around but did not see this mentioned anywhere in the Flux docs.
The only way I can think of is converting .pth to .onnx by:
import torch.onnx
import torchvision
import torch
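dummy_input = #...
model = #...
# Load the saved weights into the model
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
# Switch to inference mode so dropout and batch norm layers use their inference behavior
model.eval()
# Trace the model with the example input and write it out as ONNX
torch.onnx.export(model, dummy_input, "model.onnx")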
After that, load the .onnx file using ONNX.jl. It seems that this library is currently being reworked, but the old API might work for you. Double-check the result, since there could be discrepancies after loading the model.
Also, this discussion is relevant: https://github.com/FluxML/ML-Coordination-Tracker/issues/10
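One way to do that double check is to compare the output of the original PyTorch model with the output of the loaded model on the same input. Below is a minimal sketch in plain Python, assuming you can dump both outputs as flat lists of floats (for example `model(dummy_input).flatten().tolist()` on the PyTorch side, and an equivalent JSON dump on the Julia side). The helper names `save_reference`, `load_values`, `max_abs_diff`, `outputs_match` and the tolerance `atol=1e-4` are my own choices for illustration, not part of PyTorch, Flux.jl, or ONNX.jl.

```python
import json


def save_reference(values, path):
    """Write a flat sequence of numbers to a JSON file (hypothetical helper)."""
    with open(path, "w") as f:
        json.dump([float(v) for v in values], f)


def load_values(path):
    """Read a flat list of floats back from a JSON file (hypothetical helper)."""
    with open(path) as f:
        return [float(v) for v in json.load(f)]


def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two flat outputs."""
    if len(reference) != len(candidate):
        raise ValueError(
            f"length mismatch: {len(reference)} vs {len(candidate)}"
        )
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def outputs_match(reference, candidate, atol=1e-4):
    """True if every element agrees within the absolute tolerance atol.

    atol=1e-4 is an assumed default; pick a value that suits your model.
    """
    return max_abs_diff(reference, candidate) <= atol
```

If `outputs_match` returns False for an input where both models should agree, the conversion probably changed something (a missing layer, a different padding convention, transposed weights), and `max_abs_diff` gives a rough idea of how far off it is.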