I'm trying to use torch.autograd to train a simple recurrent neural network that predicts the next character in a sequence of characters representing songs in ABC notation.
The model looks like this:
model = keras.Sequential([
    keras.layers.Input(shape=(SEQ_LENGTH,), batch_size=batch_size),
    keras.layers.Embedding(len(vocabulary), 256),
    keras.layers.LSTM(1024, return_sequences=True, stateful=stateful),
    keras.layers.Dense(len(vocabulary))
])
The training process looks like this:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
for i in range(1000):
    inputs, targets = random_inputs_and_targets(vectorized_songs, seq_length=SEQ_LENGTH, batch_size=BATCH_SIZE)
    predictions = model(inputs)
    # CrossEntropyLoss expects (batch, classes, seq), hence the permute
    loss = loss_fn(predictions.permute(0, 2, 1), torch.from_numpy(targets).long())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
I then save the model parameters and load them into a similar model, but this time the model is stateful and has batch size 1:
torch.save(model.state_dict(), os.path.join(cwd, "model.pt"))
trained_model = build_model(1, True)
trained_model.load_state_dict(torch.load(os.path.join(cwd, "model.pt")))
trained_model.eval()
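Here build_model is just my own helper that builds the same architecture as above, parameterized by batch size and statefulness; roughly this (a sketch, the real definition is in the full code linked below):

def build_model(batch_size, stateful):
    # Same architecture as the training model above; only the batch size and
    # statefulness differ between the training and generation configurations.
    return keras.Sequential([
        keras.layers.Input(shape=(SEQ_LENGTH,), batch_size=batch_size),
        keras.layers.Embedding(len(vocabulary), 256),
        keras.layers.LSTM(1024, return_sequences=True, stateful=stateful),
        keras.layers.Dense(len(vocabulary))
    ])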
Then, I use the loaded model to predict a string of characters that I expect to look like a song in ABC notation:
input_eval = [char_to_index[s] for s in start_string]
input_eval = torch.unsqueeze(torch.tensor(input_eval), 0)
text_generated = []
for i in range(generation_length):
    predictions = torch.squeeze(model(input_eval), 0)
    predicted_index = torch.multinomial(softmax(predictions, dim=0), 1, replacement=True)[-1, 0]
    input_eval = torch.unsqueeze(torch.unsqueeze(predicted_index, 0), 0)
    text_generated.append(index_to_char[predicted_index.item()])

return start_string + ''.join(text_generated)
The full code is here.
During the 1000 training iterations, the loss goes down from around 4.42 to 0.78, as expected.
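For what it's worth, the starting value looks sane: an untrained model predicting uniformly gives a cross-entropy of ln(len(vocabulary)), which is about 4.42 for a vocabulary of roughly 83 characters (my rough estimate; the exact size isn't shown above).

import math
# Cross-entropy of a uniform guess over an assumed ~83-character vocabulary
print(math.log(83))  # ~4.42, matching the initial loss value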
But when I then try to use the "trained" model to generate a song, the result looks like a random string: XwQ5>ab>6q6S(z']!<hxaG4..M= (=ERp/xJmS|qIh_CzbM0D-N 6Yc=Ei[tcodBsEKfW<WZ5Jb("u1rrGLcFIk"PVk.'FEII:(qu7.nFbw^3/RY2LyrW. An example of the full result can be seen here.
How do I even start debugging what is going wrong? Previously I built a simple non-recurrent classifier using torch.autograd; its outputs were only about 90% accurate, but that was still much better than what I'm getting with this RNN. Could it be that the hidden state the RNN needs to predict the next character is lost somewhere during training or during prediction?
Any suggestions are welcome, since I'm getting stuck.
Managed to find the issue. Going to post an answer here for reference.
The problem was in the inference step. The model outputs log softmax, so to get the probability distribution I need to take the exponential of the model output, i.e. predictions.exp(). Instead, I was incorrectly applying softmax to the output (torch.nn.functional.softmax(predictions, dim=0)).
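The corrected sampling step inside the generation loop therefore looks roughly like this (a sketch using the same variable names as above, assuming the model's final output is log softmax as described):

predictions = torch.squeeze(model(input_eval), 0)
# The output is already log probabilities, so exponentiate instead of applying softmax
predicted_index = torch.multinomial(predictions.exp(), 1, replacement=True)[-1, 0]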