Tags: python, pytorch, pytorch-lightning

PyTorch Lightning does not terminate on macOS Metal/MPS (M4 Max) when num_workers > 0


I've been trying to train some basic models using PyTorch Lightning on an M4 Max Mac Studio. The training itself goes without a hitch, but the program fails to terminate at the end. After training completes, the following (expected) message is printed:

Epoch 4: 100%|██████████| 782/782 [00:03<00:00, 249.01it/s, v_num=14]
Trainer.fit stopped: `max_epochs=5` reached.

However, after this the program hangs indefinitely, and I have to press Ctrl+C to stop it. Here is the code. I adapted it from this PyTorch Lightning tutorial (original file), but removed the callbacks (model checkpointing and early stopping) and the learning-rate scheduler to simplify things.

import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Define the CNN architecture
class CIFAR10CNN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()

        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer, mode="min", factor=0.1, patience=5
        # )
        return optimizer

if __name__ == "__main__":
    # Data transformations
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    transform_test = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Load CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )
    val_dataset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=14, persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=14, persistent_workers=True)

    # Initialize the model
    model = CIFAR10CNN()

    # Initialize the Trainer
    trainer = L.Trainer(
        max_epochs=5,
        accelerator="mps",
        devices="auto",
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

One workaround I've found is to set num_workers to 0, but this slows training down considerably. I suspect the problem has something to do with multiprocessing and how the DataLoader worker processes are started and shut down.

Interestingly, PyTorch Lightning's own very basic tutorial runs without a hitch, even with num_workers=14.

I've also tried removing the validation loader, but it made no difference; the program still hangs. Running it in PyCharm and in a bare terminal behaves the same way. The hang also occurs with num_workers set to 1. Has anyone else run into this, or can anyone provide some pointers?
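To see where the process is actually stuck, one option is Python's standard-library faulthandler module. The sketch below is a hedged diagnostic, not something from the original post: the helper names arm_exit_watchdog and disarm_exit_watchdog are hypothetical, and the 30-second default is arbitrary. It schedules a dump of every thread's stack in the main process if the interpreter is still alive when the timer expires; if the script exits cleanly first, nothing is printed.

```python
import faulthandler
import sys
from typing import TextIO


def arm_exit_watchdog(timeout_s: float = 30.0, file: TextIO = sys.stderr) -> None:
    """Schedule a dump of all threads' Python stacks after timeout_s seconds.

    Hypothetical helper. If the process is still running when the timer fires
    (e.g. stuck during shutdown), the stacks are written to `file`.
    exit=False keeps the process alive so it can still be inspected afterwards.
    """
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=file, exit=False)


def disarm_exit_watchdog() -> None:
    """Cancel a pending dump scheduled by arm_exit_watchdog (hypothetical helper)."""
    faulthandler.cancel_dump_traceback_later()
```

In the question's script, arm_exit_watchdog(30.0) would go on the line right after trainer.fit(model, train_loader, val_loader). Note that this only reports the main process's threads; to inspect a DataLoader worker process you would need an external tool such as py-spy (py-spy dump --pid <PID>).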

Other details and package versions:

Thanks!


Solution

  • It turns out PyTorch Lightning had nothing to do with this at all: even a plain vanilla PyTorch training loop failed to terminate. Calling os._exit(0) at the end works, but the more permanent solution I found was to update my PyTorch installation to the nightly build.

    pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
    

    At the time of writing, this installs a development build of the upcoming PyTorch 2.8, which seems to have solved the issue. I don't know the root cause, but it's possible that the combination of the new M4 Max chip and macOS 15.5 exposed a bug in how PyTorch shuts down its multiprocessing workers.
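For anyone who can't move to a nightly build, the os._exit(0) workaround mentioned above could be wrapped along these lines. This is a minimal sketch; the helper name hard_exit is hypothetical. Because os._exit terminates the process immediately, it skips atexit handlers, pending finally blocks and the normal flushing of buffered output, which is why the sketch flushes stdout and stderr first. Anything else that must reach disk (checkpoints, logger files) should already be written before it is called.

```python
import os
import sys
from typing import NoReturn


def hard_exit(code: int = 0) -> NoReturn:
    """Flush standard streams, then terminate immediately via os._exit.

    Hypothetical helper. os._exit bypasses interpreter cleanup entirely,
    so the shutdown path that hangs is never reached.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(code)
```

In the question's script this would be the last line of the `if __name__ == "__main__":` block, after trainer.fit(...). The DataLoader worker processes are expected to notice that their parent has gone and exit on their own, but it is worth confirming that no stray Python processes are left behind.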