pytorch pytorch-lightning

How to manually specify checkpoint path in PyTorch Lightning


Currently I am using TensorBoardLogger for all my needs and it's perfect, but I don't like how it handles checkpoint naming. I'd prefer to specify the filename and the folder for the checkpoint manually. How can I do that?


Solution

  • Yes, it is possible thanks to the ModelCheckpoint callback:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    
    # Save checkpoints to a custom directory with a custom filename template.
    # The metrics referenced in the template must be logged during training.
    checkpoint_callback = ModelCheckpoint(
        dirpath="best_models",
        filename="{epoch}-{val_loss:.2f}-{other_metric:.2f}",
    )
    trainer = Trainer(callbacks=[checkpoint_callback])
    

    will save checkpoint files such as best_models/epoch=2-val_loss=0.02-other_metric=0.03.ckpt, for example. Each {name} placeholder in the template is expanded into a name=value pair using the logged metric values.
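To make the naming behaviour concrete, here is a rough pure-Python sketch of how such a filename template could be expanded into `name=value` pairs. This is an illustration only, not Lightning's actual implementation (the real logic lives inside `ModelCheckpoint`, and `expand_filename` here is a hypothetical helper):

```python
import re

def expand_filename(template: str, metrics: dict) -> str:
    """Expand "{name}" / "{name:fmt}" placeholders into "name=value" pairs,
    mimicking the style of Lightning's checkpoint filenames."""
    def repl(match: re.Match) -> str:
        name, fmt = match.group(1), match.group(2)
        value = metrics[name]
        # Apply the format spec (e.g. ".2f") if one was given
        return f"{name}={value:{fmt}}" if fmt else f"{name}={value}"
    return re.sub(r"\{([^{}:]+)(?::([^{}]+))?\}", repl, template) + ".ckpt"

print(expand_filename(
    "{epoch}-{val_loss:.2f}-{other_metric:.2f}",
    {"epoch": 2, "val_loss": 0.0213, "other_metric": 0.0349},
))
# epoch=2-val_loss=0.02-other_metric=0.03.ckpt
```

This also shows why the metrics named in the template have to be available (i.e. logged) at save time: the callback needs their values to fill in the placeholders.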