I've been trying to train some basic models using PyTorch Lightning on an M4 Max Mac Studio. While the training itself goes without a hitch, there appears to be a problem when the program tries to terminate at the end. After training has completed, the following (expected) message is printed:
Epoch 4: 100%|██████████| 782/782 [00:03<00:00, 249.01it/s, v_num=14]
Trainer.fit stopped: `max_epochs=5` reached.
However, after this, the program simply hangs forever, and I have to press Ctrl+C to actually stop it. Here is the code. I've adapted it from this PyTorch Lightning tutorial (original file), but I've removed the callbacks (model checkpointing and early stopping) and the learning-rate scheduler to simplify things.
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Define the CNN architecture
class CIFAR10CNN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer, mode="min", factor=0.1, patience=5
        # )
        return optimizer


if __name__ == "__main__":
    # Data transformations
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    transform_test = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Load CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )
    val_dataset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(
        train_dataset, batch_size=64, shuffle=True, num_workers=14, persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, num_workers=14, persistent_workers=True
    )

    # Initialize the model
    model = CIFAR10CNN()

    # Initialize the Trainer
    trainer = L.Trainer(
        max_epochs=5,
        accelerator="mps",
        devices="auto",
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)
One solution I've found is to set num_workers to 0, but this really slows training down. I suspect it has something to do with multiprocessing and how the DataLoader worker processes are spun up and torn down.
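For reference, this is the num_workers=0 variant I mean. Note that persistent_workers=True has to be dropped as well, since (as far as I know) DataLoader rejects that option when num_workers is 0:

    # Single-process data loading: slow, but the script exits cleanly.
    # persistent_workers is removed because DataLoader requires num_workers > 0 for it.
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)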
Another interesting thing is that PyTorch Lightning's own very basic tutorial runs without a hitch, even with num_workers=14.
I've also tried removing the validation loader, but there was no change: it still hangs. Running it from PyCharm and from a bare terminal made no difference either, and the hang still happens with num_workers set to 1. Has anyone else run into this who can provide some pointers?
Other details and package versions:
Thanks!
Turns out, PyTorch Lightning had nothing to do with this at all. Even a plain vanilla PyTorch training loop had the same trouble terminating. Calling os._exit(0) at the end of the script works as a stopgap (rough sketch at the end of this answer), but the more permanent solution I found was to update my PyTorch installation to the nightly build:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
At the time of writing, this installs a dev build of the upcoming PyTorch 2.8, which seems to have solved the issue. I don't know the root cause, but it's possible that the new M4 Max chip combined with macOS 15.5 surfaced a bug in how PyTorch shuts down its multiprocessing workers.
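For completeness, here is roughly what the os._exit(0) stopgap looked like in the training script above. os._exit skips the normal interpreter shutdown (atexit handlers, worker cleanup), which is exactly the part that was hanging, so only use it as a last resort once everything you care about (checkpoints, logs) has been written out:

    import os

    if __name__ == "__main__":
        # ... dataset, loaders, model, and trainer set up as in the question ...
        trainer.fit(model, train_loader, val_loader)

        # Stopgap: terminate the process immediately, bypassing the interpreter
        # shutdown sequence where the DataLoader worker cleanup was hanging.
        os._exit(0)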