I want to train a VGG16 model with Horovod and PyTorch on 4 GPUs. Instead of loading CIFAR10 through torchvision.datasets.CIFAR10, I would like to split the dataset on my own. So I downloaded the dataset from the official website and split it into one shard per worker. This is how I split the data:
if __name__ == '__main__':
    import pickle

    import numpy as np

    train_data, train_label = [], []
    test_data, test_label = [], []
    # Each of the 5 official batches holds 10000 images: the first 8000 go to
    # the training set, the remaining 2000 to the test set.
    for i in range(1, 6):
        with open('/Users/wangqipeng/Downloads/cifar-10-batches-py/data_batch_{}'.format(i), 'rb') as f:
            b = pickle.load(f, encoding='bytes')
            train_data.extend(b[b'data'].tolist()[:8000])
            train_label.extend(b[b'labels'][:8000])
            test_data.extend(b[b'data'].tolist()[8000:])
            test_label.extend(b[b'labels'][8000:])
    num_train = len(train_data)
    num_test = len(test_data)
    print(num_train, num_test)
    train_data = np.array(train_data)
    test_data = np.array(test_data)
    # Write one quarter of the training set per worker.
    for i in range(4):
        with open('/Users/wangqipeng/Downloads/train_{}'.format(i), 'wb') as f:
            d = {b'data': train_data[int(0.25 * i * num_train): int(0.25 * (i + 1) * num_train)],
                 b'labels': train_label[int(0.25 * i * num_train): int(0.25 * (i + 1) * num_train)]}
            pickle.dump(d, f)
    # All workers share the same test set, so it is written once.
    with open('/Users/wangqipeng/Downloads/test', 'wb') as f:
        d = {b'data': test_data,
             b'labels': test_label}
        pickle.dump(d, f)
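To double-check the split, reloading one shard (same hypothetical paths as above) should show 10000 images of 3072 raw pixel values each:

import pickle

# Hypothetical check: reload one shard and confirm its size.
with open('/Users/wangqipeng/Downloads/train_0', 'rb') as f:
    shard = pickle.load(f)
print(shard[b'data'].shape, len(shard[b'labels']))  # expect (10000, 3072) and 10000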
However, I found that if I use the dataset downloaded from the official website, there is an exploding gradient problem: the loss increases and becomes "nan" after several iterations. This is how I read the dataset:
def unpickle(file):
    # Helper from the CIFAR-10 website: load one pickled batch file.
    import pickle
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

class DataSet(torch.utils.data.Dataset):
    def __init__(self, path):
        self.dataset = unpickle(path)

    def __getitem__(self, index):
        # Each row of b'data' is a flat array of 3072 raw pixel values
        # (red, green, then blue planes), reinterpreted as a 3x32x32 tensor.
        data = torch.tensor(
            self.dataset[b'data'][index], dtype=torch.float32).resize(3, 32, 32)
        return data, torch.tensor(self.dataset[b'labels'][index])

    def __len__(self):
        return len(self.dataset[b'data'])

# Each Horovod rank reads its own shard, so no sampler is passed.
train_dataset = DataSet("./cifar10/train_" + str(hvd.rank()))
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=None, **kwargs)
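Note that, unlike torchvision's ToTensor, this __getitem__ returns raw pixel values in [0, 255]. A quick sanity check (a hypothetical snippet using the DataSet above) makes the range visible:

# Hypothetical sanity check: inspect the value range of one sample.
sample, label = train_dataset[0]
print(sample.min().item(), sample.max().item())  # something like 0.0 255.0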
If I print the loss at every iteration, I see something like this:
Mon Nov 9 11:28:29 2020[0]<stdout>:epoch 0 iter[ 0 / 313 ] loss 7.725658416748047 accuracy 5.46875
Mon Nov 9 11:28:29 2020[0]<stdout>:epoch 0 iter[ 1 / 313 ] loss 15.312677383422852 accuracy 8.59375
Mon Nov 9 11:28:29 2020[0]<stdout>:epoch 0 iter[ 2 / 313 ] loss 16.333066940307617 accuracy 9.375
Mon Nov 9 11:28:30 2020[0]<stdout>:epoch 0 iter[ 3 / 313 ] loss 15.549728393554688 accuracy 9.9609375
Mon Nov 9 11:28:30 2020[0]<stdout>:epoch 0 iter[ 4 / 313 ] loss 14.090616226196289 accuracy 9.843750298023224
Mon Nov 9 11:28:31 2020[0]<stdout>:epoch 0 iter[ 5 / 313 ] loss 12.310989379882812 accuracy 9.63541641831398
Mon Nov 9 11:28:31 2020[0]<stdout>:epoch 0 iter[ 6 / 313 ] loss 11.578919410705566 accuracy 9.15178582072258
Mon Nov 9 11:28:31 2020[0]<stdout>:epoch 0 iter[ 7 / 313 ] loss 13.210229873657227 accuracy 8.7890625
Mon Nov 9 11:28:32 2020[0]<stdout>:epoch 0 iter[ 8 / 313 ] loss 764.713623046875 accuracy 9.28819477558136
Mon Nov 9 11:28:32 2020[0]<stdout>:epoch 0 iter[ 9 / 313 ] loss 4.590414250749922e+20 accuracy 8.984375
Mon Nov 9 11:28:32 2020[0]<stdout>:epoch 0 iter[ 10 / 313 ] loss nan accuracy 9.446022659540176
Mon Nov 9 11:28:33 2020[0]<stdout>:epoch 0 iter[ 11 / 313 ] loss nan accuracy 10.09114608168602
Mon Nov 9 11:28:33 2020[0]<stdout>:epoch 0 iter[ 12 / 313 ] loss nan accuracy 10.39663478732109
However, if I use the dataset from torchvision, everything is fine:
train_dataset = \
    datasets.CIFAR10(args.train_dir, download=True,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
Something could also be wrong with the DistributedSampler, but as far as I understand, the DistributedSampler only serves to split the data across workers, so I am not sure whether it can be a reason for this problem.
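For reference, here is a minimal standalone sketch illustrating my understanding that the DistributedSampler only partitions the dataset indices across ranks (padding the last shard if needed):

import torch

# Minimal sketch: each rank gets a disjoint subset of the dataset indices.
dataset = list(range(10))  # stand-in dataset; the sampler only calls len() on it
for rank in range(4):
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], ...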
Is there something wrong with the way I read the CIFAR10 dataset? Or is there something wrong with the way I "reshape" it? Thanks for your help!
Update: maybe it is because I did not normalize the dataset. Thanks for everyone's help!
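For anyone who hits the same problem: the tensors coming out of my DataSet hold raw values in [0, 255], while the torchvision pipeline applies ToTensor (scaling to [0, 1]) before Normalize. A sketch of the fix (a hypothetical NormalizedDataSet class), doing the equivalent scaling and normalization inside __getitem__; it also swaps the deprecated Tensor.resize for reshape:

class NormalizedDataSet(torch.utils.data.Dataset):
    def __init__(self, path):
        self.dataset = unpickle(path)

    def __getitem__(self, index):
        data = torch.tensor(
            self.dataset[b'data'][index], dtype=torch.float32).reshape(3, 32, 32)
        data = data / 255.0        # same scaling as transforms.ToTensor
        data = (data - 0.5) / 0.5  # same as Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return data, torch.tensor(self.dataset[b'labels'][index])

    def __len__(self):
        return len(self.dataset[b'data'])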