I am trying to understand how exactly the tqdm progress bar works. I have some code that looks as follows:
```python
import torch
import torchvision
from torchvision import transforms

import data_setup  # custom module containing create_dataloaders()

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()
```
Then, within the main function, I have placed the train function that exists in the `engine.py` file:
```python
def main():
    results = engine.train(model=model,
                           train_dataloader=train_dataloader,
                           test_dataloader=test_dataloader,
                           optimizer=optimizer,
                           loss_fn=loss_fn,
                           epochs=5,
                           device=device)
```
The `engine.train()` function includes the following loop:

```python
for epoch in tqdm(range(epochs)):
```

inside which the training for each batch takes place, so that the training progress is visualized. However, on every step, the following statements are also printed alongside the `tqdm` output:
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
So finally, my question is why this is happening. How does the `main` function have access to these global statements, and how can I avoid printing everything on each loop iteration?
What you are noticing actually has nothing to do with `tqdm`, but rather with the inner workings of PyTorch (in particular, the `DataLoader`'s `num_workers` argument) and Python's underlying `multiprocessing` framework. Here is a minimal working example that should reproduce your problem:
```python
from contextlib import suppress
from multiprocessing import set_start_method

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

print("torch version:", torch.__version__)

class DummyData(Dataset):
    def __len__(self): return 256
    def __getitem__(self, i): return i

def main():
    for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
        pass  # Do something

if __name__ == "__main__":
    # Enforce "spawn" method (e.g. on Linux) for subprocess creation to
    # reproduce problem (suppress error for reruns in same interpreter)
    with suppress(RuntimeError): set_start_method("spawn")
    main()
```
If you run this piece of code, you should see your PyTorch version number printed four additional times (once per worker, on top of the single print from the parent process), messing up your `tqdm` progress bar. It is no coincidence that this count is the same as `num_workers` (which you can easily check by changing that number).
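For illustration, the console output then looks roughly like this (the version string is whatever your installation reports, and the exact interleaving with the progress bar varies between runs):

```
torch version: 2.2.0
  0%|          | 0/16 [00:00<?, ?it/s]torch version: 2.2.0
torch version: 2.2.0
torch version: 2.2.0
torch version: 2.2.0
100%|██████████| 16/16 [00:01<00:00, 10.52it/s]
```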
What happens is the following:

- If `num_workers` is > 0, then subprocesses are launched for the workers.
- These subprocesses are created with the "spawn" start method (the default on Windows and macOS; on Linux it has to be enforced explicitly, e.g. via `set_start_method()`).
- With "spawn", each subprocess imports your script anew and executes everything in it that is not guarded by the `if __name__ == "__main__":` block. This includes your `print()` calls on top of the script.
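This re-import behavior is not PyTorch-specific. Here is a minimal sketch using only the standard library that shows module-level code running once per spawned child:

```python
from multiprocessing import Process, set_start_method

# With the "spawn" start method, each child process imports this module
# again, so this line is executed by the parent AND by every child.
print("module-level code runs")

def work():
    pass

if __name__ == "__main__":
    set_start_method("spawn")
    processes = [Process(target=work) for _ in range(4)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    # Total output: "module-level code runs" is printed 5 times
    # (once by the parent, once by each of the 4 children).
```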
The behavior is documented in PyTorch's notes on multi-process data loading, along with potential mitigations. The one that would work for you, I guess, is:

> Wrap most of your main script's code within an `if __name__ == '__main__':` block, to make sure it doesn't run again when each worker process is launched.
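Applied to the reproduction script above, that simply means moving the module-level `print()` into the guarded block:

```python
from contextlib import suppress
from multiprocessing import set_start_method

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

class DummyData(Dataset):
    def __len__(self): return 256
    def __getitem__(self, i): return i

def main():
    for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
        pass  # Do something

if __name__ == "__main__":
    with suppress(RuntimeError): set_start_method("spawn")
    # The workers re-import the module but skip this guarded block,
    # so the version is now printed exactly once.
    print("torch version:", torch.__version__)
    main()
```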
So, either

- move your `print()` calls to the beginning of your `if __name__ == '__main__':` block,
- move your `print()` calls to the beginning of your `main()` function, or
- remove the `print()` calls.

Alternatively, but this is probably not what you want, you can set `num_workers=0`, which will disable the underlying use of `multiprocessing` altogether (but in this way you will also lose the benefits of parallelization). Note that you should probably also move the other top-level function calls (such as `load_data()`) into the `if __name__ == '__main__':` block or into the `main()` function to avoid multiple unintended executions, as sketched below.
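For your script, such a restructuring could look roughly like this (a sketch only: `load_data`, `data_setup`, `engine`, and the model/optimizer setup are your own code, assumed from the snippets in your question):

```python
import torch
import torchvision
from torchvision import transforms

import data_setup  # your modules, as in the question
import engine

def main():
    # Only the parent process executes main(); the DataLoader workers
    # re-import the module but never enter this function.
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

    load_data()
    manual_transforms = transforms.Compose([])
    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()

    # model, optimizer, loss_fn and device come from your own setup code
    results = engine.train(model=model,
                           train_dataloader=train_dataloader,
                           test_dataloader=test_dataloader,
                           optimizer=optimizer,
                           loss_fn=loss_fn,
                           epochs=5,
                           device=device)

if __name__ == "__main__":
    main()
```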