This is a snippet of my PyTorch code. My Jupyter notebook gets stuck when I use num_workers > 0, and I've spent a lot of time on this problem without finding an answer. I don't have a GPU; I work only with a CPU.
```python
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# NUM_LABELED, NUM_UNLABELED, BATCH_SIZE and the train/test tensors are defined elsewhere.


class IndexedDataset(Dataset):
    def __init__(self, data, targets, test=False):
        self.dataset = data
        self.test = test
        if not test:
            self.labels = targets.numpy()
            self.mask = np.concatenate((np.zeros(NUM_LABELED), np.ones(NUM_UNLABELED)))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]
        if self.test:
            # No labels are stored for the test set, so return the image only
            return image
        return image, self.labels[idx]

    def display(self, idx):
        plt.imshow(self.dataset[idx], cmap='gray')
        plt.show()


train_set = IndexedDataset(train_data, train_target, test=False)
test_set = IndexedDataset(test_data, test_target, test=True)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2)
```
Any help is appreciated.
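When `num_workers` is greater than 0, `DataLoader` loads batches in separate worker processes, and Jupyter notebooks have known issues with multiprocessing. In particular, with the `spawn` start method (the default on Windows, and on macOS since Python 3.8), each worker has to import the dataset class by name. A class defined in a notebook cell lives in the notebook's interactive `__main__`, which the workers cannot import, so they can crash and leave the loop hanging.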
One way to resolve this is to not use a Jupyter notebook for this code: write a normal `.py` file, put the top-level code under an `if __name__ == "__main__":` guard, and run it from the command line.
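A minimal sketch of the script approach, assuming synthetic tensors in place of the question's `train_data`/`train_target` and a simplified `IndexedDataset` that always returns labels. The file name `train_loader_demo.py` and the helper `count_samples` are hypothetical. The `__main__` guard stops workers started with `spawn` from re-running the loading code when they re-import the script.

```python
# train_loader_demo.py  (hypothetical file name)
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Simplified stand-in for the question's dataset (labels always present)."""

    def __init__(self, data, targets):
        self.dataset = data
        self.labels = targets

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.labels[idx]


def count_samples(num_workers=2, batch_size=16):
    # Synthetic data; the shapes are assumptions, not the asker's real data
    data = torch.randn(100, 28, 28)
    targets = torch.randint(0, 10, (100,))
    loader = DataLoader(IndexedDataset(data, targets),
                        batch_size=batch_size, num_workers=num_workers)
    total = 0
    for images, labels in loader:
        total += images.shape[0]
    return total


if __name__ == "__main__":
    # Only the parent process runs this; spawned workers import the file
    # as a module and skip this block.
    print(count_samples())
```

Run it with `python train_loader_demo.py`; it should iterate through all 100 synthetic samples using two worker processes.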
Alternatively, try what's suggested here: Jupyter notebook never finishes processing using multiprocessing (Python 3).
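A commonly recommended workaround that keeps you in the notebook is to move the dataset class into its own `.py` module and import it. Worker processes can then import the class by its module name instead of looking for it in the notebook's `__main__`. A minimal sketch, assuming the hypothetical module name `indexed_dataset.py` and synthetic data. In a notebook you would normally create the file with a `%%writefile indexed_dataset.py` cell or a text editor rather than `Path.write_text`.

```python
import importlib
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Source of the hypothetical module; it contains only the dataset class.
MODULE_SOURCE = '''
from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    def __init__(self, data, targets):
        self.dataset = data
        self.labels = targets

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.labels[idx]
'''

Path("indexed_dataset.py").write_text(MODULE_SOURCE)
sys.path.insert(0, str(Path.cwd()))  # make sure the new file is importable
importlib.invalidate_caches()        # the file was created at runtime

from indexed_dataset import IndexedDataset  # noqa: E402

# Synthetic data standing in for the question's tensors
data = torch.randn(100, 28, 28)
targets = torch.randint(0, 10, (100,))
loader = DataLoader(IndexedDataset(data, targets), batch_size=16, num_workers=2)

# Workers unpickle indexed_dataset.IndexedDataset by importing the module
n_seen = sum(images.shape[0] for images, _ in loader)
```

The design point is that pickle stores a reference to the class as `module.ClassName`; a real module file resolves in every worker, whereas a class defined in a notebook cell may not.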