I am trying to save my model data_gen with torch.save(). After running the train_dmc function, I can find the checkpoint file in the directory.
Here is a sample of my training function (unnecessary parts are deleted):
def train_dmc(loader, loss):
    data_gen = DataGenerator().to(device)
    data_gen_optimizer = optim.Rprop(para_list, lr=lrate)
    savepath = '/content/drive/MyDrive/' + loss + 'checkpoint.t7'
    state = {
        'epoch': epoch,
        'model_state_dict': data_gen.state_dict(),
        'optimizer_state_dict': data_gen_optimizer.state_dict(),
        'data loss': data_loss,
        'latent_loss': latent_loss
    }
    torch.save(state, savepath)
My question is how to load the checkpoint file to continue training if Google Colab disconnects.
Should I call load_state_dict() on data_gen or on train_dmc()? This is my first time using this, and I am really confused because data_gen is defined inside another function. I hope someone can help with an explanation.
data_gen.load_state_dict(torch.load(PATH))
data_gen.eval()
#or
train_dmc.load_state_dict(torch.load(PATH))
train_dmc.eval()
Since the state variable is a dictionary, try saving it with pickle:
import pickle

with open('/content/checkpoint.t7', 'wb') as handle:
    pickle.dump(state, handle, protocol=pickle.HIGHEST_PROTOCOL)
Instantiate your model class as data_gen = DataGenerator().to(device).
And load the checkpoint file as:
import pickle

# Open in binary mode; the with block closes the file automatically.
with open('/content/checkpoint.t7', 'rb') as file:
    loaded_state = pickle.load(file)
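Then you can load the state_dict using data_gen.load_state_dict(loaded_state['model_state_dict']). This copies the saved weights into the model instance. Note that assigning loaded_state['model_state_dict'] directly to data_gen would replace the model with a plain dictionary instead of loading it.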
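As an alternative to pickle, torch.save() already serializes dictionaries, so a checkpoint written as in the question can be read back with torch.load(). Since train_dmc is a plain function, it has no load_state_dict(); the restore has to happen inside it, on data_gen and data_gen_optimizer. Below is a minimal sketch of that idea, not the asker's actual training code. The DataGenerator class, the random batch, the loss, the learning rate, and the temporary save path are all hypothetical placeholders chosen so the example runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class DataGenerator(nn.Module):
    """Hypothetical stand-in for the DataGenerator in the question."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def train_dmc(savepath, num_epochs, device="cpu"):
    # Build the model and optimizer first, exactly as in a fresh run.
    data_gen = DataGenerator().to(device)
    data_gen_optimizer = optim.Rprop(data_gen.parameters(), lr=0.01)
    start_epoch = 0

    # If a checkpoint exists, restore the saved states into those objects.
    if os.path.exists(savepath):
        checkpoint = torch.load(savepath, map_location=device)
        data_gen.load_state_dict(checkpoint["model_state_dict"])
        data_gen_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch

    data_gen.train()  # training mode; eval() is meant for inference
    for epoch in range(start_epoch, num_epochs):
        x = torch.randn(8, 4, device=device)   # placeholder batch
        data_loss = data_gen(x).pow(2).mean()  # placeholder loss
        data_gen_optimizer.zero_grad()
        data_loss.backward()
        data_gen_optimizer.step()

        # Overwrite the checkpoint at the end of every epoch.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": data_gen.state_dict(),
                "optimizer_state_dict": data_gen_optimizer.state_dict(),
                "data_loss": data_loss.item(),
            },
            savepath,
        )
    return data_gen


# Hypothetical path; on Colab this would point into the mounted Drive folder.
savepath = os.path.join(tempfile.mkdtemp(), "checkpoint.t7")
train_dmc(savepath, num_epochs=2)  # first session: runs epochs 0 and 1
train_dmc(savepath, num_epochs=5)  # after a disconnect: resumes at epoch 2
last_epoch = torch.load(savepath)["epoch"]
```

After a disconnect, calling train_dmc again with the same save path is enough: the function finds the file, restores both state dictionaries, and picks up at the next epoch. Restoring the optimizer state as well as the model weights matters for optimizers such as Rprop, which keep per-parameter step sizes between updates.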