In this example, I want z_proto to be shared (global) across the different GPUs. However, in data parallel mode it gets split across the GPUs just like the other inputs. How can I solve this problem? Thank you.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        # Embed the input tokens with the underlying encoder.
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        # Distance from every query embedding to every class prototype
        # (euclidean_dist is a helper defined elsewhere).
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        loss_val = -log_p_y.gather(1, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)
        return loss_val, y_hat
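For context, here is a minimal sketch of why z_proto ends up split (the Probe class and the tensor shapes are made up, not from the original model): nn.DataParallel chunks every tensor argument to forward along dim 0 before dispatching to the replicas, so the [label_num, z_dim] prototype tensor is scattered exactly like the batch.

import torch
import torch.nn as nn

class Probe(nn.Module):
    def forward(self, x, z_proto):
        # Each replica only sees its chunk of every tensor argument.
        print("device:", x.device, "x:", tuple(x.shape), "z_proto:", tuple(z_proto.shape))
        return x

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(Probe()).cuda()
    x = torch.randn(32, 768).cuda()        # the batch is meant to be split
    z_proto = torch.randn(10, 768).cuda()  # the prototypes are not, but they get chunked anyway
    model(x, z_proto)                      # each GPU receives only a slice of z_proto, e.g. [3, 768]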
It turns out that DataParallel only replicates the nn.Parameters of the nn.Module, while plain tensor arguments to forward are scattered across the GPUs. So I randomly initialized an nn.Parameter named z_proto inside the module and copied the value of the tensor z_proto into that parameter. The parameter is then replicated onto all 4 GPUs.
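A minimal sketch of that workaround, showing only the prototype-related pieces (the z_dim constructor argument and the set_prototypes helper are my own names, not from the original code): register z_proto as an nn.Parameter so DataParallel replicates it to every GPU, and copy the externally computed prototypes into it before the forward pass.

import torch
import torch.nn as nn

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num, z_dim):
        super().__init__()
        self.seq_model = seq_model
        self.label_num = label_num
        # Randomly initialised placeholder; as a Parameter it is replicated
        # (not scattered) onto every GPU by DataParallel.
        self.z_proto = nn.Parameter(torch.randn(label_num, z_dim))

    def set_prototypes(self, z_proto):
        # Copy the computed prototype tensor into the parameter in place,
        # without recording the copy in autograd.
        with torch.no_grad():
            self.z_proto.copy_(z_proto)

With this in place, forward no longer needs to take z_proto as an argument; it can read self.z_proto directly, and each replica sees the full prototype matrix.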