python, pytorch, tqdm

Unexpected printouts interfere with tqdm progress bar in PyTorch training run


I am trying to understand how exactly the progress bar from tqdm works. I have some code that looks as follows:

import torch
import torchvision
from torchvision import transforms

import data_setup, engine  # local modules referenced below

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()

# Then, within the main function, I have placed the train function that exists in the `engine.py` file
def main():

    results = engine.train(model=model,
                           train_dataloader=train_dataloader,
                           test_dataloader=test_dataloader,
                           optimizer=optimizer,
                           loss_fn=loss_fn,
                           epochs=5,
                           device=device)

The engine.train() function contains the loop `for epoch in tqdm(range(epochs)):`, inside which the training for each batch takes place, so that tqdm can visualize the progress of the training. However, every time tqdm advances a step, the following statements are also printed:

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

So finally, my question is why this is happening. How does the main function have access to these global statements, and how can I avoid printing everything in each loop?


Solution

  • What you are noticing actually has nothing to do with tqdm, but rather with the inner workings of PyTorch (in particular, the DataLoader's num_workers attribute) and Python's underlying multiprocessing framework. Here is a minimal working example that should reproduce your problem:

    from contextlib import suppress
    from multiprocessing import set_start_method
    import torch
    from torch.utils.data import DataLoader, Dataset
    from tqdm import tqdm
    print("torch version:", torch.__version__)
    
    class DummyData(Dataset):
        def __len__(self): return 256
        def __getitem__(self, i): return i
    
    def main():
        for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
            pass  # Do something
        
    if __name__ == "__main__":
        # Enforce "spawn" method (e.g. on Linux) for subprocess creation to
        # reproduce problem (suppress error for reruns in same interpreter)
        with suppress(RuntimeError): set_start_method("spawn")
        main()
    

    If you run this piece of code, you should see your PyTorch version number printed exactly 4 times, messing up your tqdm progress bar. It is no coincidence that this count equals num_workers (which you can easily verify by changing that value).
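
    If you want to confirm that these extra printouts really come from the worker processes, you can tag the module-level print with the name of the process that executes it (a small sketch; the exact worker names depend on your platform and Python version):

    from multiprocessing import current_process
    import torch

    # Replace the module-level print in the example above with this one:
    # each spawned worker re-imports the module, so the line runs once per
    # worker and shows which process produced it.
    print("torch version:", torch.__version__, "- printed by", current_process().name)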

    What happens is the following: with num_workers=4, the DataLoader launches four worker processes to load batches in parallel. Under the "spawn" start method (the default on Windows and macOS), each worker starts a fresh Python interpreter and re-imports your main script as a module, so every statement at module level, including your two print() calls, is executed once per worker. Only code guarded by if __name__ == '__main__': is skipped in the workers.

    The behavior is documented in the torch.utils.data.DataLoader documentation (under "Multi-process data loading" / "Platform-specific behaviors"), along with potential mitigations. The one that would work for you, I guess, is:

    Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again when each worker process is launched

    So, either

    1. move the print() calls to the beginning of your if __name__ == '__main__': block (see the sketch after this list),
    2. move the print() calls to the beginning of your main() function, or
    3. remove the print() calls.
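
    Applied to the script from your question, option 1 could look roughly like this (a sketch only; the data_setup and engine modules, load_data(), and the model/optimizer/loss setup are taken from your snippet and assumed to exist):

    import torch
    import torchvision
    from torchvision import transforms

    import data_setup, engine  # your local modules

    def main():
        manual_transforms = transforms.Compose([])
        train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()
        # ... build model, optimizer, loss_fn and pick a device here ...
        results = engine.train(model=model,
                               train_dataloader=train_dataloader,
                               test_dataloader=test_dataloader,
                               optimizer=optimizer,
                               loss_fn=loss_fn,
                               epochs=5,
                               device=device)

    if __name__ == "__main__":
        # Runs only in the main process; the spawned DataLoader workers
        # re-import the module but skip this block, so nothing is printed twice.
        print(f"torch version: {torch.__version__}")
        print(f"torchvision version: {torchvision.__version__}")
        load_data()
        main()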

    Alternatively (though this is probably not what you want), you can set num_workers=0, which disables the underlying use of multiprocessing altogether, but then you also lose the benefits of parallel data loading. Note that you should probably also move other function calls (such as load_data()) into the if __name__ == '__main__': block or into the main() function to avoid multiple unintended executions.
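
    For completeness, the num_workers=0 variant of the minimal example above would simply be (no worker processes are created, so the module is never re-imported):

    # Single-process data loading: batches are produced in the main process,
    # so module-level code runs only once.
    for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=0)):
        pass  # Do something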