python pytorch lstm

batch_first in PyTorch LSTM


I still don't understand `batch_first` in PyTorch's LSTM. I tried some code that someone referred me to, and on my training data it works when `batch_first=False`: the official LSTM and the manual LSTM produce the same output. However, when I change to `batch_first=True`, they no longer produce the same values. I need `batch_first=True` because my dataset is a tensor of shape (batch, sequence length, input size). Which part of the manual LSTM needs to be changed so that it produces the same output as the official LSTM when `batch_first=True`? Here is the code snippet:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The 30 scalar samples of the single training sequence, in time order.
_train_values = [
    0.14285755, 0.0, 0.04761982, 0.04761982, 0.04761982,
    0.04761982, 0.04761982, 0.09523869, 0.09523869, 0.09523869,
    0.09523869, 0.09523869, 0.04761982, 0.04761982, 0.04761982,
    0.04761982, 0.09523869, 0.0, 0.0, 0.0,
    0.0, 0.09523869, 0.09523869, 0.09523869, 0.09523869,
    0.09523869, 0.09523869, 0.09523869, 0.14285755, 0.14285755,
]

# Wrap every sample in its own one-element list and the whole sequence in a
# single outer list, giving a tensor of shape (1, 30, 1) that tracks gradients.
train_x = torch.tensor(
    [[[value] for value in _train_values]],
    requires_grad=True,
)

# Seed both RNGs so the random input-to-hidden weights are reproducible.
seed = 23
torch.manual_seed(seed)
np.random.seed(seed)

# Reference layer: one input feature, one hidden unit, a single
# unidirectional layer constructed with batch_first=True.
pytorch_lstm = nn.LSTM(
    input_size=1,
    hidden_size=1,
    num_layers=1,
    bidirectional=False,
    batch_first=True,
)

# Input-to-hidden weights: fresh standard-normal draws.
weights = torch.randn(pytorch_lstm.weight_ih_l0.shape, dtype=torch.float)
pytorch_lstm.weight_ih_l0 = nn.Parameter(weights)

# Hidden-to-hidden weights: all ones.
pytorch_lstm.weight_hh_l0 = nn.Parameter(
    torch.ones(pytorch_lstm.weight_hh_l0.shape)
)

# Set both biases to zero, each sized from the parameter it replaces.
for _bias_name in ("bias_ih_l0", "bias_hh_l0"):
    _bias_shape = getattr(pytorch_lstm, _bias_name).shape
    setattr(pytorch_lstm, _bias_name, nn.Parameter(torch.zeros(_bias_shape)))
# Reference forward pass. The result is indexed below as
# [0] -> output, [1][0] -> hidden state, [1][1] -> cell state.
pytorch_lstm_out = pytorch_lstm(train_x)

batch_size = 1

# Manual Calculation
# Each packed parameter is split along dim 0 into four one-row chunks, which
# are unpacked in the order: input gate (i), forget gate (f), cell candidate
# (g), output gate (o).
W_ii, W_if, W_ig, W_io = pytorch_lstm.weight_ih_l0.split(1, dim=0)
b_ii, b_if, b_ig, b_io = pytorch_lstm.bias_ih_l0.split(1, dim=0)

W_hi, W_hf, W_hg, W_ho = pytorch_lstm.weight_hh_l0.split(1, dim=0)
b_hi, b_hf, b_hg, b_ho = pytorch_lstm.bias_hh_l0.split(1, dim=0)

# Zero initial hidden and cell state. These must use `batch_size` (the name
# defined above); the misspelling `batchsize` raises NameError.
prev_h = torch.zeros((batch_size, 1))
prev_c = torch.zeros((batch_size, 1))

# NOTE: the gate equations are evaluated once, on the whole of train_x, always
# against the zero state above -- no hidden or cell state is carried from one
# time step to the next here.
i_t = torch.sigmoid(F.linear(train_x, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(train_x, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(train_x, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(train_x, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)

# Print the reference results next to the manual ones for comparison.
print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(pytorch_lstm_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(pytorch_lstm_out[1][1], c_t))

Solution

  • You have to iterate over the sequence one element at a time, feeding the hidden and cell states computed at each time step in as the inputs to the next time step:

    # Start from zero hidden and cell state, one row per batch entry.
    h_t = torch.zeros((batch_size, 1))
    c_t = torch.zeros((batch_size, 1))

    # Hidden state produced at every time step, in time order.
    hidden_seq = []

    # Walk dim 1 of train_x (the time axis of the batch-first input) one step
    # at a time; the length is read from the tensor rather than hard-coded.
    for t in range(train_x.size(1)):
        x_t = train_x[:, t, :]
        # Every gate sees the current input and the previous hidden state.
        i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
        f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
        g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
        o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
        # The updated cell and hidden state feed the next iteration.
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        hidden_seq.append(h_t)

    # Stack the per-step states along a new dim 1, so the result is laid out
    # batch-first like train_x: (batch, steps, 1).
    hidden_seq = torch.stack(hidden_seq, dim=1)

    print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], hidden_seq))