python, deep-learning, pytorch, pytorch-lightning

When to use prepare_data vs setup in pytorch lightning?


PyTorch Lightning's docs only say, in the code,

def prepare_data(self):
    # download
    ...

and

def setup(self, stage: Optional[str] = None):
    # Assign train/val datasets for use in dataloaders
    ...

Please explain the intended separation between prepare_data and setup, what callbacks may occur between them, and why one would put something in one rather than the other.


Solution

  • If you look at the pseudocode for the Trainer.fit function provided in the documentation page of LightningModule at § Hooks, you can read:

    def fit(self):
        if global_rank == 0:
            # prepare data is called on GLOBAL_ZERO only
            prepare_data()                                 ## <-- prepare_data
    
        configure_callbacks()
    
        with parallel(devices):
            # devices can be GPUs, TPUs, ...
            train_on_device(model)
    
    
    def train_on_device(model):
        # called PER DEVICE
        on_fit_start()
        setup("fit")                                       ## <-- setup
        configure_optimizers()
    
        # the sanity check runs here
    
        on_train_start()
        for epoch in epochs:
            fit_loop()
        on_train_end()
    
        on_fit_end()
        teardown("fit")
    

    You can see that prepare_data is called only when global_rank == 0, i.e. it is called by a single process. This matches the documentation description of prepare_data:

    LightningModule.prepare_data()
    Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

    Whereas setup is called on every process, as you can see in the pseudocode above and in its documentation description:

    LightningModule.setup(stage=None)
    Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
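    The calling pattern above can be simulated in plain Python, without Lightning installed. In this sketch, ToyDataModule, fit, and world_size are illustrative names (not Lightning APIs): prepare_data runs exactly once, while setup runs once per simulated process, which is why per-process state (dataset splits, model attributes) belongs in setup and one-time disk work (downloads) belongs in prepare_data.

    ```python
    # Toy simulation of Lightning's hook order: prepare_data once (rank 0),
    # setup once per process/device. Names here are illustrative only.

    class ToyDataModule:
        def __init__(self):
            self.splits = {}

        def prepare_data(self):
            # One-time, single-process work: download files, tokenize,
            # write to disk. Avoid assigning state you need later --
            # other processes never see attributes set here.
            pass

        def setup(self, stage):
            # Per-process work: build datasets and assign state that
            # every process needs (e.g. train/val splits).
            if stage == "fit":
                self.splits["train"] = list(range(8))
                self.splits["val"] = list(range(8, 10))

    def fit(module, world_size=2):
        calls = []
        # prepare_data: called on GLOBAL_ZERO only
        module.prepare_data()
        calls.append("prepare_data")
        # setup: called on every process/device
        for rank in range(world_size):
            module.setup("fit")
            calls.append(f"setup(rank={rank})")
        return calls

    print(fit(ToyDataModule()))
    ```

    Running this prints prepare_data once followed by one setup per simulated rank, mirroring why state set in prepare_data would be visible only on rank 0 in a real DDP run.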