In PyTorch, it is possible to save model checkpoints as follows:
import torch
# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)
# ... some training here
# Save checkpoint
torch.save(model.state_dict(), 'checkpoint.pt')
During my training procedure, I save a checkpoint every 100 epochs or so. Currently this results in a folder with many files, e.g.
checkpoint0.pt
checkpoint100.pt
checkpoint200.pt
I was wondering whether it is possible to append checkpoints to an existing file, so that instead of cluttering my disk with many small files I have a single file called checkpoints.pt. I have currently implemented this as follows:
import torch
# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)
# ... some training here
# Save 1st checkpoint
data = {'0': model.state_dict()}
torch.save(data, 'checkpoints.pt')
# ... some training here
# Save 2nd checkpoint
data = torch.load('checkpoints.pt')
data['100'] = model.state_dict()
torch.save(data, 'checkpoints.pt')
print(torch.load('checkpoints.pt'))
The problem is that this requires loading the entire existing file into memory before appending a new checkpoint, which is memory-intensive, especially since I have hundreds of checkpoints. Is there a way to do this (or something similar) without loading the existing checkpoints back into memory?
See this post on storing multiple pickled objects in the same file. In short, PyTorch checkpointing is built on pickle, so if you use a thin pickle wrapper instead of the default torch.save, you can easily accomplish this:
import pickle  # in Python 3, pickle automatically uses the C implementation (_pickle) when it is available
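def append_save(network, path):
    # "ab" = binary append: the new pickle is written after the existing ones,
    # so earlier checkpoints are never read back into memory
    with open(path, "ab") as f:
        pickle.dump(network.state_dict(), f)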
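One thing this loses compared with the dictionary in the question is the epoch label ('0', '100', ...). A minimal sketch of a variant, assuming each record is stored as an `(epoch, state_dict)` tuple; `append_labeled` and `read_labeled` are hypothetical helper names, and plain dicts stand in for `model.state_dict()` so the snippet runs without PyTorch:

```python
import os
import pickle
import tempfile


def append_labeled(state_dict, epoch, path):
    """Append one (epoch, state_dict) record to the end of `path` (hypothetical helper)."""
    # Binary append mode: existing records are neither read nor rewritten.
    with open(path, "ab") as f:
        pickle.dump((epoch, state_dict), f)


def read_labeled(path):
    """Return {epoch: state_dict} for every record in `path` (hypothetical helper)."""
    records = {}
    with open(path, "rb") as f:
        while True:
            try:
                epoch, state_dict = pickle.load(f)
            except EOFError:  # reached the end of the last pickle
                break
            records[epoch] = state_dict
    return records


# Demo: plain dicts stand in for model.state_dict().
path = os.path.join(tempfile.mkdtemp(), "checkpoints.pkl")
for epoch in (0, 100, 200):
    append_labeled({"weight": [epoch]}, epoch, path)

checkpoints = read_labeled(path)
print(sorted(checkpoints))  # [0, 100, 200]
```

With a real model you would call `append_labeled(model.state_dict(), epoch, 'checkpoints.pkl')` inside the training loop. The file is a sequence of plain pickles, so read it back with `pickle.load` rather than `torch.load`.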
To load the checkpoints back, read the state dicts from the file one after another until you reach the end:
def read_checkpoints(path):
    checkpoints = []
    with open(path, "rb") as f:
        while True:
            try:
                # each call reads exactly one pickled state dict
                checkpoints.append(pickle.load(f))
            except EOFError:  # no more pickles in the file
                break
    return checkpoints
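Note that `read_checkpoints` builds a list holding every checkpoint at once. If you only need to walk through them, or want a single one, a generator keeps just one state dict alive at a time. A sketch under the same assumptions (`iter_checkpoints` and `load_checkpoint` are hypothetical names; plain dicts stand in for state dicts):

```python
import itertools
import os
import pickle
import tempfile


def iter_checkpoints(path):
    """Yield the pickled objects in `path` one at a time, in the order they were appended."""
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:  # end of file: stop the generator
                return


def load_checkpoint(path, index):
    """Return the checkpoint at 0-based position `index`.

    Earlier records are still unpickled to advance the file position,
    but each one is discarded immediately, so only one is held in memory.
    """
    try:
        return next(itertools.islice(iter_checkpoints(path), index, None))
    except StopIteration:
        raise IndexError(f"{path} has no checkpoint at index {index}") from None


# Demo: append three stand-in "state dicts", then read one back lazily.
path = os.path.join(tempfile.mkdtemp(), "checkpoints.pkl")
for epoch in (0, 100, 200):
    with open(path, "ab") as f:
        pickle.dump({"epoch": epoch}, f)

print(load_checkpoint(path, 1))  # {'epoch': 100}
```

With a real model this would look like `model.load_state_dict(load_checkpoint('checkpoints.pkl', 1))`.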