Tags: python, pytorch, torch, checkpoint

How to load a saved model defined in a function in PyTorch in Colab?


I am trying to save my model, data_gen, with torch.save(). After running the train_dmc function, I can find the checkpoint file in the directory.

Here is sample code from my training function (unnecessary parts removed):

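import torch
import torch.optim as optim

def train_dmc(loader, loss):
    data_gen = DataGenerator().to(device)
    data_gen_optimizer = optim.Rprop(para_list, lr=lrate)

    # ... training loop omitted (defines epoch, data_loss, latent_loss) ...

    # Bundle everything needed to resume into one dict and write it to Drive
    savepath = '/content/drive/MyDrive/' + loss + 'checkpoint.t7'
    state = {
        'epoch': epoch,
        'model_state_dict': data_gen.state_dict(),
        'optimizer_state_dict': data_gen_optimizer.state_dict(),
        'data loss': data_loss,
        'latent_loss': latent_loss,
    }
    torch.save(state, savepath)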

My question is how to load the checkpoint file to continue training if Google Colab disconnects.

Should I load it into data_gen or into train_dmc()? This is my first time doing this, and I am confused because data_gen is defined inside another function. I hope someone can help me with an explanation.

data_gen.load_state_dict(torch.load(PATH))
data_gen.eval()

#or

train_dmc.load_state_dict(torch.load(PATH))
train_dmc.eval()

Solution

  • Since the state variable is a plain dictionary, you can save it with pickle:

    import pickle

    with open('/content/checkpoint.t7', 'wb') as handle:
        pickle.dump(state, handle, protocol=pickle.HIGHEST_PROTOCOL)

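    To restore it, first create a new instance of your model class: data_gen = DataGenerator().to(device). It does not matter that the original data_gen was a local variable inside train_dmc; a fresh DataGenerator instance has the same parameters, so the saved weights can be loaded into it.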

    Then load the checkpoint file:

    import pickle

    with open('/content/checkpoint.t7', 'rb') as handle:
        loaded_state = pickle.load(handle)

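    Finally, copy the saved weights into the model with data_gen.load_state_dict(loaded_state['model_state_dict']). Note that writing data_gen = loaded_state['model_state_dict'] would not load anything: it would replace the model object with a plain dictionary of tensors.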
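Putting the pieces together, here is a minimal sketch of resuming training from a checkpoint written with torch.save, as in the question (torch.load reads such a file back directly, so the extra pickle step is optional). Because the real DataGenerator is not shown, the class below is a hypothetical stand-in with one linear layer, and save_checkpoint / load_checkpoint are illustrative helper names, not PyTorch APIs. It also assumes para_list in the question is data_gen.parameters(). The point it illustrates: you never load into train_dmc itself; you build a fresh model and optimizer of the same kind and copy the saved state into both.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DataGenerator(nn.Module):
    """Hypothetical stand-in for the question's DataGenerator."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)


def save_checkpoint(path, epoch, model, optimizer, data_loss, latent_loss):
    # Same dictionary layout as the question's `state`.
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'data loss': data_loss,
        'latent_loss': latent_loss,
    }
    torch.save(state, path)


def load_checkpoint(path, lrate=0.01):
    # Rebuild fresh objects of the same classes, then copy the saved state into them.
    model = DataGenerator().to(device)
    optimizer = optim.Rprop(model.parameters(), lr=lrate)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    model.train()  # training mode; eval() is meant for inference
    return model, optimizer, start_epoch


# Usage: save after epoch 4, then resume from epoch 5.
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.t7')
model = DataGenerator().to(device)
optimizer = optim.Rprop(model.parameters(), lr=0.01)
# One Rprop step so the optimizer has per-parameter state to save.
model(torch.randn(8, 4, device=device)).pow(2).mean().backward()
optimizer.step()
save_checkpoint(path, 4, model, optimizer, data_loss=0.5, latent_loss=0.1)

resumed_model, resumed_optimizer, start_epoch = load_checkpoint(path)
print(start_epoch)  # 5
```

Your training loop would then run from start_epoch instead of 0. Restoring the optimizer state matters for Rprop in particular, because it keeps per-parameter step sizes that would otherwise restart from their defaults. Also keep in mind that files under /content/ outside /content/drive live on the Colab VM's local disk and are lost when the runtime is reset, so a Drive path like the question's savepath is the safer place for a checkpoint meant to survive disconnects.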