julia flux.jl

How to use a .pth model in Flux.jl?


I have a model trained in PyTorch, saved in .pth format. Is it possible to use and load that model in Flux.jl? I looked around but did not see this mentioned anywhere in the Flux docs.


Solution

  • The only way I can think of is:

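    1. Convert the .pth file to .onnx (this step runs in Python, with PyTorch):
    import torch
    import torch.onnx
    import torchvision
    
    # A .pth file usually stores only the state_dict (the weights), so the
    # architecture must be rebuilt first. resnet18 is only a placeholder:
    # substitute the model class that was used for training.
    model = torchvision.models.resnet18()
    state_dict = torch.load("model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # put dropout/batchnorm layers in inference mode before exporting
    
    # Example input with the shape the model expects (assumed: one 224x224 RGB image).
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "model.onnx")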
    
    2. Load the model from the .onnx file using ONNX.jl. At the time of writing, this library was being reworked, but the old API might still work for you. Double-check the results, because the loaded model may produce outputs that differ from the original.

    Also, this discussion is relevant: https://github.com/FluxML/ML-Coordination-Tracker/issues/10
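    One way to do that double-check (a suggestion, not something ONNX.jl provides) is to run the same input through the original PyTorch model and the loaded Flux model, export both outputs as arrays, and compare them numerically. A minimal sketch in Python with NumPy follows; the function name and the tolerance defaults are hypothetical choices. Julia arrays are column-major and Flux puts the batch dimension last, so a Flux output may arrive with its axes in reverse order compared to PyTorch; the optional `reverse_axes` flag accounts for that.

```python
import numpy as np


def compare_outputs(reference, candidate, reverse_axes=False, rtol=1e-4, atol=1e-5):
    """Compare a reference (PyTorch) output with an output from the converted model.

    reverse_axes: if True, reverse the axis order of `candidate` first, e.g. a
    Flux output shaped (classes, batch) vs PyTorch's (batch, classes).
    Returns (max_abs_diff, all_close).
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reverse_axes:
        # np.transpose with no axes argument reverses the order of all axes.
        candidate = np.transpose(candidate)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    all_close = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    return max_abs_diff, all_close


# Usage with synthetic data: a (batch=2, classes=2) output and its axis-reversed copy.
ref = np.array([[0.1, 0.9], [0.3, 0.7]])
print(compare_outputs(ref, ref.T, reverse_axes=True))  # (0.0, True)
```

    In practice the two arrays would come from files written on each side (for example via `np.load` on the Python side); a large `max_abs_diff` usually points to a layout or preprocessing mismatch rather than wrong weights.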