PyTorch Lightning's docs on DataModules only say, in the code,

def prepare_data(self):
    # download
    ...

and

def setup(self, stage: Optional[str] = None):
    # Assign train/val datasets for use in dataloaders
    ...

Please explain the intended separation between prepare_data and setup, what callbacks may occur between them, and why to put something in one rather than the other.
If you look at the pseudo-code for the Trainer.fit function provided on the LightningModule documentation page, under § Hooks, you can read:
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()  ## <-- prepare_data

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")  ## <-- setup
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")
You can see prepare_data being called only when global_rank == 0, i.e. it is only called by a single process. Indeed, the documentation description of prepare_data reads:
LightningModule.prepare_data()
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
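To make that concrete, here is a minimal, illustrative sketch of a LightningDataModule whose prepare_data only downloads files. The use of torchvision's MNIST and the data/ path are arbitrary choices for the example, not something the docs prescribe:

import pytorch_lightning as pl
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # Runs in a single process (global_rank == 0): safe to download or
        # write to disk. Do not assign state here (e.g. self.train_ds = ...),
        # because the other processes would never see that attribute.
        MNIST(root="data/", train=True, download=True)
        MNIST(root="data/", train=False, download=True)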
Whereas setup is called on all processes, as you can see both in the pseudo-code above and in its documentation description:
LightningModule.setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
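Continuing the same illustrative DataModule, setup is where per-process state gets assigned; the split sizes and the ToTensor transform below are again just example choices:

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: Optional[str] = None):
        # Runs on every process, so attributes assigned here exist on every
        # device when train_dataloader() / val_dataloader() are built.
        if stage in (None, "fit"):
            full = MNIST(root="data/", train=True, transform=transforms.ToTensor())
            self.train_ds, self.val_ds = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            self.test_ds = MNIST(root="data/", train=False, transform=transforms.ToTensor())

In short: work that writes to shared storage once (downloads, one-off preprocessing) belongs in prepare_data, while anything each process needs in memory (splits, transforms, the dataset objects themselves) belongs in setup.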