python, pytorch, lstm, recurrent-neural-network

Pytorch LSTM - initializing hidden states during training


I have a class that contains my LSTM model and a training loop over some data (trajectories of a pendulum). When I train the model I have to initialize the hidden state for each timestep. This confuses me, since I thought the power of an LSTM (RNN) is that it uses the previous hidden state for the next calculation, yet I set it to zeros every time. Even so, the model works very well for predicting the pendulum. (This code is "heavily inspired" by an article that uses it on a very similar problem.)

This is the Model Class:

class LSTMmodel(nn.Module):
    
    def __init__(self,input_size,hidden_size_1,hidden_size_2,out_size):
        
        super().__init__()
        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        self.input_size = input_size
        self.lstm_1 = nn.LSTM(input_size,hidden_size_1)
        self.lstm_2 = nn.LSTM(hidden_size_1,hidden_size_2)
        self.linear = nn.Linear(hidden_size_2,out_size)
        self.hidden_1 = (
            torch.zeros(1,1,hidden_size_1),
            torch.zeros(1,1,hidden_size_1)
        )
        self.hidden_2 = (
            torch.zeros(1,1,hidden_size_2),
            torch.zeros(1,1,hidden_size_2)
        )
        
    def forward(self,seq):
        lstm_out_1 , self.hidden_1 = self.lstm_1(seq.view(-1,1,self.input_size),self.hidden_1) 
        lstm_out_2 , self.hidden_2 = self.lstm_2(lstm_out_1,self.hidden_2)  
        pred = self.linear(lstm_out_2.view(len(seq),-1))
        return pred

This is the training loop:

def train(model, ddt):

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

    model.train()
    # Set the number of epochs
    epochs = 50

    for epoch in range(epochs):
        
        # Running each batch separately 
        
        for bat in range(0, len(training_data), data[0].size(dim=0)):
            # model.hidden_1 = (torch.zeros(1,1,model.hidden_size_1), torch.zeros(1,1,model.hidden_size_1))
            # model.hidden_2 = (torch.zeros(1,1,model.hidden_size_2), torch.zeros(1,1,model.hidden_size_2))

            for seq, label in training_data[bat:bat+data[0].size(dim=0)]:
                model.hidden_1 = (torch.zeros(1,1,model.hidden_size_1), torch.zeros(1,1,model.hidden_size_1))
                model.hidden_2 = (torch.zeros(1,1,model.hidden_size_2), torch.zeros(1,1,model.hidden_size_2))
        
                seq=seq.to(device)
                label=label.to(device)

                # set the optimization gradient to zero
                optimizer.zero_grad()

                model.zero_grad()
                # initialize the hidden states
                
                # Make predictions on the current sequence
                if ddt: 
                    y_pred = model(seq) + seq # learn derivative?
                else:
                    y_pred = model(seq)

                # Compute the loss
                loss = loss_fn(y_pred, label)         
                # Perform backpropagation and gradient descent

                loss.backward(retain_graph=True)

                optimizer.step()

The model should simply predict the pendulum's angle at the next timestep given its current position.
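
For reference, training_data in the loop above is a list of (seq, label) pairs built from the trajectory. A hypothetical sketch of how such pairs could be constructed (the window length, the sine stand-in trajectory, and the single angle feature are illustrative assumptions, not taken from the original code):

import torch

# Build (seq, label) training pairs from a 1-D series of pendulum angles.
# Each label is the input window shifted one timestep into the future.
def make_windows(angles, window=10):
    pairs = []
    for i in range(len(angles) - window):
        seq = angles[i:i + window].unsqueeze(-1)            # (window, 1) inputs
        label = angles[i + 1:i + window + 1].unsqueeze(-1)  # next-step targets
        pairs.append((seq, label))
    return pairs

angles = torch.sin(torch.linspace(0, 10, 500))  # stand-in for a real trajectory
training_data = make_windows(angles)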

If I instead initialize the hidden states only at the start of each batch (the commented-out lines above), I get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Trying to set loss.backward(retain_graph=True) does not solve the issue. I get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [15, 40]], which is output 0 of AsStridedBackward0, is at version 7206; expected version 7205 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I assume this is because some tensor inside the model is changed after the gradient update? Any suggestions on how to fix this would be great!
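
The first error is usually a sign that the stored hidden state still references the computation graph built in earlier iterations, so the next backward() walks into a graph that has already been freed. A common way to carry the state across updates without backpropagating through old graphs is to detach it each step (truncated backpropagation through time). A minimal self-contained sketch of that pattern, with arbitrary sizes rather than the original model:

import torch
import torch.nn as nn

lstm = nn.LSTM(1, 4)                 # (input_size, hidden_size), arbitrary sizes
linear = nn.Linear(4, 1)
opt = torch.optim.Adam(list(lstm.parameters()) + list(linear.parameters()))
loss_fn = nn.MSELoss()

hidden = (torch.zeros(1, 1, 4), torch.zeros(1, 1, 4))
for step in range(3):
    # keep the hidden state's values, but cut its link to the previous graph
    hidden = tuple(h.detach() for h in hidden)

    seq = torch.randn(10, 1, 1)      # (timesteps, batch, features)
    target = torch.randn(10, 1)

    opt.zero_grad()
    out, hidden = lstm(seq, hidden)
    pred = linear(out.view(10, -1))
    loss = loss_fn(pred, target)
    loss.backward()                  # works; no retain_graph needed
    opt.step()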

I am a bit confused, since I thought the point of LSTMs (and RNNs in general) is to retain the hidden state and use it in the next timestep. But now I set the hidden state to zero and it still works pretty well.

Thanks in advance!


Solution

  • You need to re-init the hidden state for each batch. As for using the previous hidden state: your batch input should contain multiple timesteps. The input to nn.LSTM has shape (sl, bs, d_in) (or (bs, sl, d_in) if batch_first=True), where sl is the number of timesteps in the batch. You re-init the hidden state for each batch, but each batch has multiple timesteps, and each timestep uses the hidden state from the previous timestep. This all happens within nn.LSTM (see the sketch at the end of this answer).

    You want something like this:

    class LSTMModel(nn.Module):
        def __init__(self, input_size, d_hidden, num_layers, output_size):
            super().__init__()
            
            self.d_hidden = d_hidden
            self.num_layers = num_layers
            
            # I find `batch_first=True` is easier to work with
            self.lstm = nn.LSTM(input_size, d_hidden, num_layers, batch_first=True)
            self.linear = nn.Linear(d_hidden, output_size)
            
        def forward(self, x, hidden=None):
            
            if hidden is None:
                hidden = self.get_hidden(x)
                
            x, hidden = self.lstm(x, hidden)
            x = self.linear(x)
            return x, hidden
        
        def get_hidden(self, x):
            # note the second axis is batch size, which is `x.shape[0]` for `batch_first=True`
            hidden = (
                    torch.zeros(self.num_layers, x.shape[0], self.d_hidden, device=x.device),
                    torch.zeros(self.num_layers, x.shape[0], self.d_hidden, device=x.device),
                    )
            return hidden 
    

    The reason to return both the output and the hidden state is to re-use the hidden state for inference. Inference would look something like this:

    input_value = ... # some initial input
    hidden = None
    prediction_steps = 5 # number of timesteps to predict 
    preds = []
    
    with torch.no_grad():
        for i in range(prediction_steps):
            # output + hidden are inputs for the next timestep
            input_value, hidden = model(input_value, hidden)
            preds.append(input_value)
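
    To make the "this all happens within nn.LSTM" point concrete, here is a small self-contained sketch (sizes are arbitrary) showing that one call over a whole sequence produces the same outputs as stepping through the sequence manually and passing the hidden state along yourself:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    lstm = nn.LSTM(input_size=3, hidden_size=8, num_layers=1, batch_first=True)
    x = torch.randn(2, 5, 3)  # (batch, timesteps, features)

    # one call over the whole sequence
    out_full, _ = lstm(x)

    # stepping manually, feeding the hidden state forward at each timestep
    hidden = None
    steps = []
    for t in range(x.shape[1]):
        out_t, hidden = lstm(x[:, t:t + 1, :], hidden)
        steps.append(out_t)
    out_manual = torch.cat(steps, dim=1)

    print(torch.allclose(out_full, out_manual, atol=1e-6))  # True

    So re-initializing the hidden state per batch does not discard the recurrence within a sequence; it only resets it between independent sequences.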