I'm trying to implement a DQN. As a warm-up I want to solve CartPole-v0 with an MLP consisting of two hidden layers along with input and output layers. The input is a 4-element array [cart position, cart velocity, pole angle, pole angular velocity] and the output is an action value for each action (left or right). I am not exactly implementing the DQN from the "Playing Atari with DRL" paper (no frame stacking for inputs, etc.). I also made a few non-standard choices, like putting `done` and the target network's prediction of the action value in the experience replay, but those choices shouldn't affect learning.
In any case, I'm having a lot of trouble getting the thing to work. No matter how long I train the agent, it keeps predicting a higher value for one action than for the other, for example Q(s, Right) > Q(s, Left) for all states s. Below are my learning code, my network definition, and some results I get from training:
class DQN:
    """
    Q-learning agent with a replay memory and a periodically synced target
    network, intended for small classic-control environments.

    Known limitation: ``learn`` computes each transition's target value once,
    when the transition is collected, and stores it in the replay memory. A
    stored target is therefore never refreshed after the target network is
    re-synced, so sampled targets may come from an older target network.
    """

    def __init__(
        self,
        env,
        steps_per_episode=200,
        *,
        lr=0.005,
        momentum=0.95,
        gamma=0.99,
        random_policy_stop=1000,
        start_learning_time=1000,
        batch_size=32,
        target_update_interval=1_000,
    ):
        """
        Build the agent/target networks, the optimizer and the replay memory.

        Args:
            env: Environment providing ``reset()``, ``step(action)`` (returning
                a 4-tuple) and ``action_space.sample()``; it is also handed to
                ``MlpPolicy`` to size the networks.
            steps_per_episode: Maximum number of environment steps per episode.
            lr: RMSprop learning rate.
            momentum: RMSprop momentum.
            gamma: Discount factor passed to ``predict_target_value``.
            random_policy_stop: Global step before which actions are sampled
                uniformly from the action space.
            start_learning_time: Gradient updates start once the global step
                exceeds this value.
            batch_size: Number of experiences sampled per gradient update.
            target_update_interval: The target network is re-synced whenever
                the global step is a multiple of this value.

        Raises:
            ValueError: If ``target_update_interval`` is not positive.
        """
        if target_update_interval < 1:
            raise ValueError("target_update_interval must be a positive integer")
        self.env = env
        self.agent_network = MlpPolicy(self.env)
        self.target_network = MlpPolicy(self.env)
        # Start both networks from identical weights; the target network is
        # only ever used for inference.
        self.target_network.load_state_dict(self.agent_network.state_dict())
        self.target_network.eval()
        self.optimizer = torch.optim.RMSprop(
            self.agent_network.parameters(), lr=lr, momentum=momentum
        )
        self.replay_memory = ReplayMemory()
        self.gamma = gamma
        self.steps_per_episode = steps_per_episode
        self.random_policy_stop = random_policy_stop
        self.start_learning_time = start_learning_time
        self.batch_size = batch_size
        self.target_update_interval = target_update_interval

    def learn(self, episodes):
        """
        Interact with the environment for ``episodes`` episodes, filling the
        replay memory and, after the warm-up period, taking one gradient step
        per environment step.
        """
        total_steps = 0  # global step counter shared by all episodes
        for _episode in tqdm(range(episodes)):
            state = self.env.reset()
            for _step in range(self.steps_per_episode):
                if total_steps < self.random_policy_stop:
                    # Warm-up: act uniformly at random.
                    action = self.env.action_space.sample()
                else:
                    action = select_action(
                        self.env, total_steps, state, self.agent_network
                    )
                new_state, reward, done, _info = self.env.step(action)

                # The target is only ever used as a constant in the loss, so
                # make sure no autograd graph of the target network is kept
                # alive inside the replay memory.
                with torch.no_grad():
                    target_value_pred = predict_target_value(
                        new_state, reward, done, self.target_network, self.gamma
                    )
                experience = Experience(
                    state, action, reward, new_state, done, target_value_pred
                )
                self.replay_memory.append(experience)

                if total_steps > self.start_learning_time:  # learning step
                    experience_batch = self.replay_memory.sample(self.batch_size)
                    target_preds = extract_value_predictions(experience_batch)
                    agent_preds = agent_batch_preds(
                        experience_batch, self.agent_network
                    )
                    # Summed (not averaged) squared TD error over the batch.
                    loss = torch.square(agent_preds - target_preds).sum()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Periodically copy the agent weights into the target network.
                if total_steps % self.target_update_interval == 0:
                    self.target_network.load_state_dict(
                        self.agent_network.state_dict()
                    )
                    self.target_network.eval()

                state = new_state
                total_steps += 1
                if done:
                    break
def agent_batch_preds(experience_batch: list, agent_network: MlpPolicy):
    """
    Calculate the agent action value estimates using the old states and the
    actual actions that the agent took at that step.

    Args:
        experience_batch: Sampled experiences; passed to ``extract_old_states``
            and ``extract_actions``.
        agent_network: Network whose output is indexed along dimension 1 by
            action, i.e. one row of action values per experience.

    Returns:
        1-D tensor whose i-th entry is the network's value for the action
        taken in the i-th experience.
    """
    old_states = extract_old_states(experience_batch)
    actions = extract_actions(experience_batch)
    agent_preds = agent_network(old_states)
    # gather reads agent_preds[i, actions[i]] directly. The previous
    # index_select(...).diag() built a full batch x batch matrix only to keep
    # its diagonal. gather requires int64 indices, hence .long().
    action_index = actions.long().unsqueeze(1)
    experienced_action_values = agent_preds.gather(1, action_index).squeeze(1)
    return experienced_action_values
def extract_actions(experience_batch: list) -> torch.Tensor:
    """
    Extract the actions from an experience replay batch and torchify them.

    Args:
        experience_batch: Iterable of experiences, each exposing an ``action``
            attribute.

    Returns:
        1-D ``torch.long`` tensor with one action per experience, in batch
        order.
    """
    # An explicit integer dtype keeps the result usable as an index tensor
    # even for an empty batch, which would otherwise default to float32.
    return torch.tensor(
        [exp.action for exp in experience_batch], dtype=torch.long
    )
class MlpPolicy(nn.Module):
    """
    This class implements the MLP which will be used as the Q network. I only
    intend to solve classic control problems with this.

    The network maps an observation of size ``env.observation_space.shape[0]``
    to one value per action (``env.action_space.n`` outputs) through linear
    layers of width 32, 128 and 32 with ReLU activations.
    """

    def __init__(self, env):
        """
        Args:
            env: Environment whose ``observation_space.shape[0]`` and
                ``action_space.n`` fix the input and output sizes.
        """
        super().__init__()
        self.env = env
        self.input_dim = self.env.observation_space.shape[0]
        self.output_dim = self.env.action_space.n
        self.fc1 = nn.Linear(self.input_dim, 32)
        self.fc2 = nn.Linear(32, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, self.output_dim)

    def forward(self, x):
        """
        Compute action values for ``x``.

        Args:
            x: A tensor or anything ``torch.as_tensor`` accepts (e.g. a list or
                array) whose last dimension has size ``input_dim``.

        Returns:
            Tensor of action values with last dimension ``output_dim``; no
            activation is applied to the output layer.
        """
        # as_tensor returns tensors of the right dtype untouched (subclasses
        # such as nn.Parameter included, so autograd history is preserved) and
        # converts everything else to the dtype of the layer weights.
        x = torch.as_tensor(x, dtype=self.fc1.weight.dtype)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)
Learning results:
Here I'm seeing one action always valued over the other (Q(s, Right) > Q(s, Left)). It's also clear that the network is predicting the same action values for every state.
Does anyone have an idea about what's going on? I've done a lot of debugging and careful reading of the original papers (also thought about "normalizing" the observation space even though the velocities can be infinite) and could be missing something obvious at this point. I can include more code for the helper functions if that would be useful.
There was nothing wrong with the network definition. It turns out the learning rate was too high, and reducing it to 0.00025 (as in the original Nature paper introducing the DQN) led to an agent which can solve CartPole-v0.
That said, the learning algorithm was incorrect. In particular, I was using the wrong target action-value predictions. Note that the algorithm laid out above does not use the most recent version of the target network to make predictions: each target value is computed once, when the transition is collected, and is then stored in the replay memory. This leads to poor results as training progresses because the agent is learning from stale target data. The way to fix this is to put just (s, a, r, s', done) into the replay memory and then make target predictions using the most up-to-date version of the target network when sampling a mini-batch. See the code below for an updated learning loop.
def learn(self, episodes):
    """
    Train the agent for ``episodes`` episodes.

    Each environment step appends an ``Experience(state, action, reward,
    new_state, done)`` to the replay memory. Once the global step counter
    exceeds ``self.start_learning_time``, every step also samples a batch,
    computes its targets with the current target network, and takes one
    optimizer step on the summed squared error. The target network is
    re-synced with the agent network whenever the counter is a multiple of
    1000.
    """
    total_steps = 0  # counts environment steps across all episodes
    for _episode in tqdm(range(episodes)):
        state = self.env.reset()
        for _step in range(self.steps_per_episode):
            # Act randomly during warm-up, then defer to select_action.
            warming_up = total_steps < self.random_policy_stop
            action = (
                self.env.action_space.sample()
                if warming_up
                else select_action(
                    self.env, total_steps, state, self.agent_network
                )
            )
            new_state, reward, done, _info = self.env.step(action)
            self.replay_memory.append(
                Experience(state, action, reward, new_state, done)
            )

            if total_steps > self.start_learning_time:
                # One gradient step; targets come from the target network as
                # it is at sampling time.
                batch = self.replay_memory.sample(self.batch_size)
                targets = target_batch_preds(
                    batch, self.target_network, self.gamma
                )
                estimates = agent_batch_preds(batch, self.agent_network)
                td_error = estimates - targets
                loss = torch.square(td_error).sum()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if total_steps % 1_000 == 0:
                # Copy the agent weights into the target network.
                agent_weights = self.agent_network.state_dict()
                self.target_network.load_state_dict(agent_weights)
                self.target_network.eval()

            state = new_state
            total_steps += 1
            if done:
                break