Tags: python, pytorch, recurrent-neural-network, bidirectional

Manual Bidirectional torch.nn.RNN Implementation


I'm trying to reimplement the torch.nn.RNN module without C++/CUDA bindings, i.e., using simple tensor operations and associated logic. I have developed the following RNN class and associated testing logic, which can be used to compare output with a reference module instance:

import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.w_ih = [torch.randn(hidden_size, input_size)]
        if bidirectional:
            self.w_ih_reverse = [torch.randn(hidden_size, input_size)]

        for layer in range(num_layers - 1):
            self.w_ih.append(torch.randn(hidden_size, hidden_size))
            if bidirectional:
                self.w_ih_reverse.append(torch.randn(hidden_size, hidden_size))

        self.w_hh = torch.randn(num_layers, hidden_size, hidden_size)
        if bidirectional:
            self.w_hh_reverse = torch.randn(num_layers, hidden_size, hidden_size)

    def forward(self, input, h_0=None):
        if h_0 is None:
            if self.bidirectional:
                h_0 = torch.zeros(2, self.num_layers, input.shape[1], self.hidden_size)
            else:
                h_0 = torch.zeros(1, self.num_layers, input.shape[1], self.hidden_size)

        if self.bidirectional:
            output = torch.zeros(input.shape[0], input.shape[1], 2 * self.hidden_size)
        else:
            output = torch.zeros(input.shape[0], input.shape[1], self.hidden_size)

        for t in range(input.shape[0]):
            print(input.shape, t)
            input_t = input[t]
            if self.bidirectional:
                input_t_reversed = input[-1 - t]

            for layer in range(self.num_layers):
                h_t = torch.tanh(torch.matmul(input_t, self.w_ih[layer].T) + torch.matmul(h_0[0][layer], self.w_hh[layer].T))
                h_0[0][layer] = h_t
                if self.bidirectional:
                    h_t_reverse = torch.tanh(torch.matmul(input_t_reversed, self.w_ih_reverse[layer].T) + torch.matmul(h_0[1][layer], self.w_hh_reverse[layer].T))
                    h_0[1][layer] = h_t_reverse

                input_t = h_t
                if self.bidirectional:
                    # This logic is incorrect for bidirectional RNNs with multiple layers
                    input_t = torch.cat((h_t, h_t_reverse), dim=-1)
                    input_t_reversed = input_t

            output[t, :, :self.hidden_size] = h_t
            if self.bidirectional:
                output[-1 - t, :, self.hidden_size:] = h_t_reverse

        return output


if __name__ == '__main__':
    input_size = 10
    hidden_size = 12
    num_layers = 2
    batch_size = 2
    bidirectional = True
    input = torch.randn(2, batch_size, input_size)
    rnn_val = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=False, bidirectional=bidirectional, nonlinearity='tanh')
    rnn = RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)
    for i in range(rnn_val.num_layers):
        rnn.w_ih[i] = rnn_val._parameters['weight_ih_l%d' % i].data
        rnn.w_hh[i] = rnn_val._parameters['weight_hh_l%d' % i].data
        if bidirectional:
            rnn.w_ih_reverse[i] = rnn_val._parameters['weight_ih_l%d_reverse' % i].data
            rnn.w_hh_reverse[i] = rnn_val._parameters['weight_hh_l%d_reverse' % i].data

    output_val, hn_val = rnn_val(input)
    output = rnn(input)
    print(output_val)
    print(output)

My implementation appears to work for vanilla RNNs with an arbitrary number of layers and different batch sizes/sequence lengths, as well as for single-layer bidirectional RNNs. However, it does not produce the correct result for multi-layer bidirectional RNNs.

For the sake of simplicity, bias terms are not currently implemented, and only the tanh activation function is supported. I have narrowed the logic error down to the line input_t = torch.cat((h_t, h_t_reverse), dim=-1), as the first output sequence is incorrect.

It would be greatly appreciated if someone could point me in the correct direction, and let me know what the problem is!


Solution

  • There are two possible approaches to forward:

    1. step through time in the outer loop and through layers in the inner loop (the approach in the question)
    2. step through layers in the outer loop and through time in the inner loop

    The first one is fine for a unidirectional RNN, but it does not work properly for a bidirectional multi-layer RNN. To illustrate, take 2 layers (the same case as in the code). Computing output[0] requires the input from the previous layer, which is the concatenation of:

    1. the hidden vector from the forward pass after 1 step (because it is right at the start of the sequence)
    2. the hidden vector from the reverse pass after seq_length steps (the pass has to step through the whole sequence, from end to start, to get it)

    So when the step through layers is taken inside each time step, each pass has advanced only one step in time, and output[0] receives garbage as input, because the second part is not correct (there was no "whole pass from end to start").
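One way to see this dependency directly is to perturb only the last time step and look at output[0] of a reference torch.nn.RNN. This is a small illustrative check, not part of the original code; the sizes and the variable names (x_changed, fwd_same, bwd_same) are arbitrary choices made for the sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 4
rnn = nn.RNN(input_size=3, hidden_size=hidden_size, bias=False, bidirectional=True)

x = torch.randn(3, 1, 3)      # (seq_len, batch, input_size)
x_changed = x.clone()
x_changed[-1] += 1.0          # perturb only the LAST time step

with torch.no_grad():
    out, _ = rnn(x)
    out_changed, _ = rnn(x_changed)

# Forward half of output[0] has only seen x[0], so it should not move.
fwd_same = torch.allclose(out[0, :, :hidden_size], out_changed[0, :, :hidden_size])
# Reverse half of output[0] is the END of the reverse pass, so it has seen x[-1].
bwd_same = torch.equal(out[0, :, hidden_size:], out_changed[0, :, hidden_size:])

print("forward half of output[0] unchanged:", fwd_same)
print("reverse half of output[0] unchanged:", bwd_same)
```

The reverse half of output[0] reacts to a change at the far end of the sequence, which is why the next layer cannot consume it until the reverse pass of the current layer has finished.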

    To overcome this issue, I'd suggest rewriting the loops in forward from:

    for t in range(input.shape[0]):
        ...
        for layer in range(self.num_layers):
            ...
    

    to something like:

    for layer in range(self.num_layers):
        ...
        for t in range(input.shape[0]):
            ...
    

    As an alternative, keep forward for the other cases, and for the bidirectional multi-layer case write a separate function forward_bidir that contains these loops.
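A minimal sketch of the layer-outer, time-inner ordering, assuming no bias terms, tanh, and a zero initial hidden state as in the question. The helper names run_direction and layer_first_forward are hypothetical (they are not from the question or the gist), and the weights are read from a reference nn.RNN purely so the result can be compared.

```python
import torch
import torch.nn as nn


def run_direction(x, w_ih, w_hh, reverse=False):
    """Run one tanh RNN direction over the whole sequence (no bias).

    x: (seq_len, batch, in_features) -> (seq_len, batch, hidden_size)
    """
    seq_len, batch, _ = x.shape
    h = torch.zeros(batch, w_hh.shape[0], dtype=x.dtype)
    outputs = [None] * seq_len
    steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
    for t in steps:
        h = torch.tanh(x[t] @ w_ih.T + h @ w_hh.T)
        outputs[t] = h  # stored at its original time index
    return torch.stack(outputs)


def layer_first_forward(x, ref):
    """Layers in the outer loop, time in the inner loop.

    `ref` is an nn.RNN built with bias=False; its weights are reused.
    """
    layer_input = x
    for layer in range(ref.num_layers):
        fwd = run_direction(layer_input,
                            getattr(ref, f"weight_ih_l{layer}"),
                            getattr(ref, f"weight_hh_l{layer}"))
        if ref.bidirectional:
            bwd = run_direction(layer_input,
                                getattr(ref, f"weight_ih_l{layer}_reverse"),
                                getattr(ref, f"weight_hh_l{layer}_reverse"),
                                reverse=True)
            # Both passes are complete here, so the next layer sees
            # a fully valid (seq_len, batch, 2 * hidden_size) input.
            layer_input = torch.cat((fwd, bwd), dim=-1)
        else:
            layer_input = fwd
    return layer_input


torch.manual_seed(0)
ref = nn.RNN(input_size=10, hidden_size=12, num_layers=2,
             bias=False, bidirectional=True)
x = torch.randn(5, 2, 10)
with torch.no_grad():
    expected, _ = ref(x)
    actual = layer_first_forward(x, ref)
print(torch.allclose(actual, expected, atol=1e-5))
```

Because each layer finishes both directions before the next layer starts, the same function covers the unidirectional case too, so a separate forward_bidir is optional with this ordering.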

    It is also worth noting that w_ih[k] for k > 0 in the bidirectional case has shape (hidden_size, 2 * hidden_size), as stated in the PyTorch documentation on RNN. Also, torch.allclose should serve better for testing outputs than prints.
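Both points can be checked quickly against a reference module. The snippet below is only an illustration with the sizes from the question; the tolerance passed to torch.allclose is an arbitrary choice.

```python
import torch
import torch.nn as nn

ref = nn.RNN(input_size=10, hidden_size=12, num_layers=2,
             bias=False, bidirectional=True)

# Collect the parameter shapes of the reference module.
shapes = {name: tuple(p.shape) for name, p in ref.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
# weight_ih_l0 is (12, 10), but weight_ih_l1 is (12, 24):
# layer 1 consumes the concatenated forward and reverse outputs of layer 0.

# torch.allclose gives a single True/False instead of two tensors to eyeball.
x = torch.randn(3, 2, 10)
with torch.no_grad():
    out_a, _ = ref(x)
    out_b, _ = ref(x)
same = torch.allclose(out_a, out_b, atol=1e-6)
different = torch.allclose(out_a, out_a + 0.1, atol=1e-6)
print(same, different)
```

With this in mind, the w_ih and w_ih_reverse entries for layers after the first need to be created with 2 * hidden_size input columns when bidirectional is set.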

    For the code fixes, check the gist. No optimisations were made; the main aim was to preserve the original idea. It seems to work for all the configurations listed above (unidirectional, multi-layer, bidirectional).