
Issue when padding and packing sequences in LSTM networks using PyTorch


I'm trying to build a simple LSTM neural network. I have time-series data that I split into sequences and batches using PyTorch's Dataset and DataLoader. To account for the variable lengths of the sequences in the last batch (since the data runs out), I use padding and packing.

I'm using a custom collate_fn in the DataLoader, which looks like this:

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

def collate_data(batch):
    sequences, targets = zip(*batch)

    lens = [len(seq) for seq in sequences]
    print(f"Lens before padding: {lens}")

    padded_seq = pad_sequence(sequences=sequences, batch_first=True,
                              padding_value=float(9.99e10))

    print(f"Lens after padding: {[len(seq) for seq in padded_seq]}")

    padded_targets = pad_sequence(sequences=targets, batch_first=True,
                                  padding_value=float(9.99e10))

    # Pack so the LSTM only processes the real (non-padded) time steps
    packed_batch = pack_padded_sequence(padded_seq, lengths=lens, batch_first=True,
                                        enforce_sorted=False)

    print(f"Packed batch lengths: {packed_batch.batch_sizes}")

    return packed_batch, padded_targets

The problem comes when I try to unpack the values in the forward method of my network, which looks like this:

def forward(self, x):
    lstm = self.lstm
    batch_size = self.batch_size

    h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
    c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)

    packed_lstm_out, (hn, cn) = lstm(x, (h0, c0))

    print(f"lstm_out size: {packed_lstm_out.data.shape}")
    unpacked_lstm_out = unpack_sequence(packed_sequences=packed_lstm_out)
    print(f"Unpacked lengths: {[len(seq) for seq in unpacked_lstm_out]}")

    unpacked_lstm_tensor = torch.stack(unpacked_lstm_out, dim=0).float().requires_grad_(True)

    print(unpacked_lstm_tensor.shape)

    output = self.fc1(unpacked_lstm_tensor[:, -1, :])

    return output

However, I get an error from torch.stack(unpacked_lstm_out, dim=0) because the tensors have different sizes. This only happens on the last batch, which is the one that needs padding.

I've added print statements, which output the following for the last batch:

Lens before padding: [10, 10, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5]
Lens after padding: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
Packed batch lengths: tensor([12, 12, 12, 12, 12, 11, 10,  9,  8,  7])
lstm_out size: torch.Size([105, 16])
Unpacked lengths: [10, 10, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5]

I suspect the issue comes from pack_padded_sequence(), but I don't understand why it happens or how to fix it.

Does anyone know how to fix this issue so that all the tensors are the same size after unpacking them in the forward function?


Solution

  • unpack_sequence() removes the padding, so, as you've noticed, the sequences are no longer the same length. unpacked_lstm_out is a list of length batch_size, where each element has shape (that sample's sequence length, hidden_size).
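    To see that packing itself is behaving correctly, here is a small sketch with random data of the same lengths as your last batch (the feature size of 3 is an arbitrary assumption; unpack_sequence() requires PyTorch 1.13 or later). pack_padded_sequence() stores only the real time steps plus, in batch_sizes, how many sequences are still active at each step, and unpack_sequence() hands back each sequence at its original length:

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, unpack_sequence

# Same lengths as the last batch in the question; feature size 3 is arbitrary
lens = [10, 10, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5]
sequences = [torch.randn(n, 3) for n in lens]

padded = pad_sequence(sequences, batch_first=True)  # shape (12, 10, 3)
packed = pack_padded_sequence(padded, lengths=lens, batch_first=True,
                              enforce_sorted=False)

# Number of sequences still "active" at each time step
print(packed.batch_sizes)   # tensor([12, 12, 12, 12, 12, 11, 10,  9,  8,  7])
# Only the non-padded time steps are stored: sum(lens) = 105 rows
print(packed.data.shape)    # torch.Size([105, 3])

# Unpacking restores the original order AND the original (unequal) lengths,
# which is why torch.stack() on the result fails for this batch
unpacked = unpack_sequence(packed)
print([len(seq) for seq in unpacked])  # [10, 10, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5]
```

    So the mismatch is expected: nothing is lost during packing, and the padding is simply not reintroduced on the way out.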

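    Instead of this line in your forward method:

    output = self.fc1(unpacked_lstm_tensor[:,-1,:])

    (which cannot work on a list of unequal-length tensors, and which, even if you re-padded the outputs, would pick up padding instead of the real last time step for the shorter sequences), I think it should be: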

    output_n = torch.stack([seq[-1, :] for seq in unpacked_lstm_out], dim=0)
    
    output = self.fc1(output_n)
    

    This takes the final time step of each sequence in the unpacked list and stacks them into a tensor of shape (batch_size, hidden_size).
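    As a sanity check, the sketch below (random data, an arbitrary input size of 3, hidden_size=16, and a single-layer unidirectional nn.LSTM, all assumptions) compares that approach with two other routes: re-padding the output with pad_packed_sequence() and indexing each row at length - 1, and reading hn[-1], which for a packed input already holds each sequence's hidden state at its own last real time step. The equivalence with hn[-1] assumes a unidirectional LSTM.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence, unpack_sequence)

torch.manual_seed(0)
lens = [10, 10, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5]
x = pad_sequence([torch.randn(n, 3) for n in lens], batch_first=True)
packed = pack_padded_sequence(x, lengths=lens, batch_first=True,
                              enforce_sorted=False)

lstm = nn.LSTM(input_size=3, hidden_size=16, batch_first=True)  # unidirectional
packed_out, (hn, cn) = lstm(packed)  # h0/c0 default to zeros when omitted

# 1) Last row of each unpacked sequence (the approach above)
last_a = torch.stack([seq[-1, :] for seq in unpack_sequence(packed_out)], dim=0)

# 2) Re-pad, then pick time step (length - 1) for each sequence
padded_out, out_lens = pad_packed_sequence(packed_out, batch_first=True)
last_b = padded_out[torch.arange(len(out_lens)), out_lens - 1]

# 3) Final hidden state of the last (here: only) layer
last_c = hn[-1]

print(last_a.shape)  # torch.Size([12, 16])
print(torch.allclose(last_a, last_b), torch.allclose(last_a, last_c))
```

    For a unidirectional model, hn[-1] is the cheapest of the three, since it needs no unpacking at all.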

    On accounting for the variable lengths of the sequences in the last batch (since the data runs out):

    Perhaps you could drop those shorter sequences? You'll lose a bit of the tail data, but it'll mean you can avoid handling variable sequence lengths. Alternatively, you could write a custom data sampler that batches equally-sized sequences together (example below). In both cases, you can use regular tensors (rather than packed sequences), which are simpler and work seamlessly with other torch.nn layers.
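    A rough sketch of the first option, assuming the windows are cut from a single 1-D series and the target is the value right after each window (WindowDataset and its arguments are hypothetical names, and the toy series stands in for real data). The Dataset simply never creates the short tail windows, so the default collate function stacks every batch into a regular (batch, seq_len, 1) tensor:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class WindowDataset(Dataset):
    """Hypothetical dataset: fixed-length windows over a 1-D series.

    Only full windows are produced, so the shorter tail windows that would
    otherwise need padding are never created.
    """
    def __init__(self, series, seq_len):
        self.series = torch.as_tensor(series, dtype=torch.float32)
        self.seq_len = seq_len

    def __len__(self):
        # Each sample needs seq_len inputs plus one target value after them
        return max(len(self.series) - self.seq_len, 0)

    def __getitem__(self, i):
        window = self.series[i:i + self.seq_len].unsqueeze(-1)  # (seq_len, 1)
        target = self.series[i + self.seq_len]                   # scalar
        return window, target

series = torch.arange(100, dtype=torch.float32)  # toy stand-in for real data
dataset = WindowDataset(series, seq_len=10)
# drop_last=True also discards a final batch with fewer than batch_size samples
loader = DataLoader(dataset, batch_size=12, shuffle=True, drop_last=True)

x, y = next(iter(loader))
print(len(dataset), x.shape, y.shape)  # 90 torch.Size([12, 10, 1]) torch.Size([12])
```

    With inputs like x, and assuming the LSTM was created with batch_first=True, the forward method needs no packing: out, (hn, cn) = self.lstm(x) followed by self.fc1(out[:, -1, :]) works as written.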


    Below is code I've previously used to draw batches from a dataset so that every sequence within a batch has the same length. For example, the first batch might contain batch_size sequences that all have length 5, and the next (randomly chosen) batch might contain batch_size sequences that all have length 13. SameLengthsBatchSampler yields the indices of the samples to use, not the samples themselves, and it is passed to the batch_sampler= parameter of DataLoader().

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Sampler
    
    # Batch sampler: yields lists of sample indices where every sample in a
    # batch has the same sequence length.
    class SameLengthsBatchSampler(Sampler):
        def __init__(self, sentences, batch_size, drop_last=False, verbose=False):
            # Use an array so `lengths == length` below compares element-wise
            # (a plain list compared to an int just returns False)
            lengths = np.array([len(sentence) for sentence in sentences])
            unique_lengths, counts = np.unique(lengths, return_counts=True)
            
            # Only consider sequence lengths where count >= batch_size
            unique_lengths = unique_lengths[counts >= batch_size]
            
            same_lens_dict = {}
            for length in unique_lengths:
                same_lens_dict[length] = np.argwhere(lengths == length).ravel()
            
            self.same_lens_dict = same_lens_dict  # sample indices organised by sequence length
            self.unique_lengths = unique_lengths
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.verbose = verbose
        
        def __len__(self):
            # Number of batches per epoch, computed without consuming the iterator
            n_batches = 0
            for indices in self.same_lens_dict.values():
                if self.drop_last:
                    n_batches += len(indices) // self.batch_size
                else:
                    n_batches += (len(indices) + self.batch_size - 1) // self.batch_size
            return n_batches
            
        def __iter__(self):
            # Visit the sequence lengths in a random order each epoch
            for i in torch.randperm(len(self.unique_lengths)).tolist():
                seq_len = self.unique_lengths[i]
                # All samples with this length, shuffled
                sample_indices = torch.as_tensor(self.same_lens_dict[seq_len])
                shuffled_ixs = sample_indices[torch.randperm(len(sample_indices))]
            
                # Split into batch-sized chunks (the last chunk may be smaller)
                indices_per_batch = torch.split(shuffled_ixs, self.batch_size)
                
                if self.drop_last and len(indices_per_batch[-1]) < self.batch_size:
                    indices_per_batch = indices_per_batch[:-1]
                
                if self.verbose:
                    print('sequence_length={} | yielding {} samples over {} batches'.format(
                        seq_len, len(sample_indices), len(indices_per_batch)
                    ))
                
                # Yield each batch as a plain list of Python int indices
                for batch in indices_per_batch:
                    yield batch.tolist()
    
    #
    # Batch data (train_dataset, val_dataset, trn_sentences and val_sentences
    # come from your own code)
    #
    batch_size = 32
    
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=SameLengthsBatchSampler(trn_sentences, batch_size)
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_sampler=SameLengthsBatchSampler(val_sentences, batch_size)
    )
    

    Some discussion of this type of functionality is available here.
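    If a Sampler subclass feels heavy, a lighter sketch of the same idea is to group indices by length yourself and pass the resulting list of index lists as batch_sampler, since DataLoader accepts any iterable of index batches there. The helper make_length_buckets and the toy data are hypothetical. Unlike SameLengthsBatchSampler, it keeps lengths with fewer than batch_size samples (as smaller batches), and because a plain list replays the same order every epoch, you would rebuild it each epoch if you want reshuffling.

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

def make_length_buckets(sequences, batch_size, drop_last=False, seed=None):
    """Hypothetical helper: lists of indices whose sequences share one length."""
    rng = random.Random(seed)
    by_len = defaultdict(list)
    for idx, seq in enumerate(sequences):
        by_len[len(seq)].append(idx)

    batches = []
    for indices in by_len.values():
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            if drop_last and len(chunk) < batch_size:
                continue
            batches.append(chunk)
    rng.shuffle(batches)  # mix the different lengths across the epoch
    return batches

# Toy data: 40 sequences with lengths between 5 and 10, feature size 3
torch.manual_seed(0)
sequences = [torch.randn(int(n), 3) for n in torch.randint(5, 11, (40,))]

batches = make_length_buckets(sequences, batch_size=4, seed=0)
loader = DataLoader(sequences, batch_sampler=batches)

for batch in loader:
    # Every sequence in a batch has the same length, so the default
    # collate function stacks them into a (b, seq_len, 3) tensor
    assert batch.dim() == 3 and batch.shape[2] == 3
```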