I'm trying to implement this DeepMind DQN paper (Mnih et al., 2015) from scratch in PyTorch, using the Atari Pong environment.
I've tested my Deep Q-Network on a simple test environment where episodes last only 5 time steps and the maximum achievable reward is 4.1. My DQN reaches this reward in roughly 70% of training runs, so I'm fairly confident the implementation is functionally correct.
However, when I apply the exact same DQN algorithm to Atari Pong, the agent fails to learn or show any meaningful improvement—even after training for nearly 3 million timesteps.
I've attached graphs of training metrics here. They are from my most recent training session, and it looks like the network cannot learn: the max reward, the average reward, and the Q-values only decrease. The most interesting part is that the loss is nearly zero, yet the network still does not learn anything close to a good policy.
The only hypothesis I have is training instability. I initially observed vanishing gradients, so I switched to torch.nn.SmoothL1Loss and added batch normalization (not shown in code), which helped the gradients but did not improve performance.
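For context, this is roughly how I check for vanishing gradients (a minimal sketch; I call it on the online network right after loss.backward() in the update step shown further down):

def grad_global_norm(model):
    # Global L2 norm over all parameter gradients; values collapsing toward
    # zero are what I'm calling "vanishing gradients" above.
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5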
My Setup
Full code on GitHub: https://github.com/rohanpatel01/Reinforcement_Learning/tree/try_testenv_dqn_but_not_work/Atari%20DQN%20Project
I removed some comments and logging so that the code here would be easier to read.
Model (DQN.py)
class NatureQN(nn.Module):

    def linear_decay(self, current_step):
        if current_step >= self.config.lr_n_steps:
            return self.config.lr_end / self.config.lr_begin
        return 1.0 - (current_step / self.config.lr_n_steps) * (1 - self.config.lr_end / self.config.lr_begin)

    # Following Model Architecture from mnih2015human
    def __init__(self, env, config, device):
        super(NatureQN, self).__init__()
        self.env = env
        self.config = config
        self.device = device

        self.conv1 = nn.Conv2d(in_channels=env.observation_space.shape[0], out_channels=32, kernel_size=8, stride=4, dtype=torch.float32)
        self.bn_conv1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, dtype=torch.float32)
        self.bn_conv2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, dtype=torch.float32)
        self.bn_conv3 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(in_features=3136, out_features=512, dtype=torch.float32)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(in_features=512, out_features=env.action_space.n, dtype=torch.float32)
        self.ReLU = nn.ReLU()

        self.optimizer = optim.RMSprop(self.parameters(), lr=self.config.lr_begin, alpha=self.config.squared_gradient_momentum, eps=self.config.rms_eps)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.linear_decay)
        self.criterion = nn.SmoothL1Loss()

    def forward(self, x):
        x = self.conv1(x)
        x = self.ReLU(x)
        x = self.conv2(x)
        x = self.ReLU(x)
        x = self.conv3(x)
        x = self.ReLU(x)

        if len(x.shape) == 3:    # single input [height, width, num_channels]
            x = torch.flatten(x)
        elif len(x.shape) == 4:  # batch input [batch_size, height, width, num_channels]
            x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.ReLU(x)
        x = self.fc2(x)
        return x
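As a sanity check, fc1's in_features=3136 matches the conv stack applied to an 84 × 84 input. A quick computation, assuming the usual (size - kernel) // stride + 1 formula for unpadded convolutions:

# Where fc1's 3136 comes from, assuming an 84x84 spatial input and no padding.
def conv_out(size, kernel, stride):
    return (size - kernel) // stride + 1

h = conv_out(conv_out(conv_out(84, 8, 4), 4, 2), 3, 1)   # 84 -> 20 -> 9 -> 7
print(h, 64 * h * h)                                     # 7 3136, matching fc1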
Main()
def main():
    config = AtariDQNConfig()
    env = gym.make("ALE/Pong-v5", frameskip=1, repeat_action_probability=0)
    env = ReducedActionSet(env, allowed_actions=[0, 2, 3])
    env = AtariPreprocessing(
        env,
        noop_max=30, frame_skip=4, terminal_on_life_loss=False,  # changed noop_max to 30 from 0
        screen_size=84, grayscale_obs=True, grayscale_newaxis=False,
        scale_obs=False
    )
    env = FrameStackObservation(env, stack_size=4)

    model = DQN(env, config, device)
    model.train()
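Before training, I find it useful to print what the fully wrapped environment actually hands to the network, since that is what conv1 ends up seeing. A small sketch, run inside main() after the wrappers are applied and before model.train():

# Quick check of the observation the wrapped env produces (sketch).
obs, info = env.reset()
print(env.observation_space.shape, env.observation_space.dtype)
print(obs.shape, obs.dtype)   # this shape/dtype feeds conv1 (after process_state)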
Hyperparameters (AtariDQNConfig.py)
The full config is in the repo linked above; I've also tried tweaking several of these hyperparameters, with no improvement.
Training Loop (in Q_learning.py)
def train(self):
    epsilon_scheduler = EpsilonScheduler(self.config.begin_epsilon, self.config.end_epsilon,
                                         self.config.max_time_steps_update_epsilon)
    time_last_saved = self.t

    while self.t <= self.config.nsteps_train:
        print("Time: ", self.t, " Epsilon: ", epsilon_scheduler.get_epsilon(self.t - self.config.learning_delay), " Learning Rate: ", self.approx_network.optimizer.param_groups[0]['lr'])

        if (self.t - time_last_saved) >= self.config.saving_freq:
            self.save_snapshop(self.t, self.num_episodes, self.total_reward_so_far, self.replay_buffer)  # just for development; not saving snapshots right now
            time_last_saved = self.t

        state, info = self.env.reset()
        state = torch.from_numpy(state)
        state = self.process_state(state)
        state = state.to(self.device)

        total_reward_for_episode = 0
        start_time = time.time()

        while True:
            with torch.no_grad():
                action = self.sample_action(self.env, state, epsilon_scheduler.get_epsilon(self.t - self.config.learning_delay), self.t, "approx")

            next_state, reward, terminated, truncated, info = self.env.step(action)
            next_state = torch.from_numpy(next_state)
            next_state = self.process_state(next_state).to(self.device)

            # convert state and next_state to uint8 before placing in replay buffer
            experience_tuple = ((state*self.config.high).to(torch.uint8).to('cpu'), action, reward, (next_state*self.config.high).to(torch.uint8).to('cpu'), terminated)
            self.replay_buffer.store(experience_tuple)
            state = next_state

            if (self.t > self.config.learning_start) and (self.t % self.config.learning_freq == 0):
                self.train_on_minibatch(self.replay_buffer.sample_minibatch(), self.t)

            self.monitor_performance(state, reward, monitor_end_of_episode=False, timestep=self.t)  # used to have env as param
            self.total_reward_so_far += reward
            total_reward_for_episode += reward
            self.t += 1

            if (self.t > self.config.learning_start) and (self.t % self.config.target_weight_update_freq == 0):
                self.set_target_weights()

            if terminated:
                # monitor avg reward per episode and max reward per episode (at end of episode)
                self.num_episodes += 1
                self.monitor_performance(state, reward, monitor_end_of_episode=True, timestep=self.t, context=(self.total_reward_so_far, self.num_episodes, total_reward_for_episode))
                break
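EpsilonScheduler is not shown above; it is meant to anneal epsilon linearly from begin_epsilon down to end_epsilon over max_time_steps_update_epsilon steps and then hold it constant. A simplified sketch of what it does (not the exact class body):

# Sketch of the epsilon schedule used above (not the exact implementation).
class EpsilonScheduler:
    def __init__(self, begin_epsilon, end_epsilon, max_time_steps):
        self.begin_epsilon = begin_epsilon
        self.end_epsilon = end_epsilon
        self.max_time_steps = max_time_steps

    def get_epsilon(self, t):
        if t <= 0:                           # still within the learning delay
            return self.begin_epsilon
        if t >= self.max_time_steps:
            return self.end_epsilon
        frac = t / self.max_time_steps
        return self.begin_epsilon + frac * (self.end_epsilon - self.begin_epsilon)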
Update function (train_on_minibatch() in Linear.py)
def train_on_minibatch(self, minibatch, timestep):
    states, actions, rewards, next_states, dones = minibatch
    states = self.process_state(states)
    next_states = self.process_state(next_states)

    states = states.to(self.device).to(torch.float32)
    actions = actions.to(self.device)
    rewards = rewards.to(self.device)
    next_states = next_states.to(self.device).to(torch.float32)
    dones = dones.to(self.device)

    q_vals = self.approx_network(states)
    q_chosen = q_vals.gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        q_next_all = self.target_network(next_states)                               # [batch_size, num_actions]
        best_actions = torch.argmax(q_next_all, dim=1)                              # [batch_size]
        next_q_values = q_next_all.gather(1, best_actions.unsqueeze(1)).squeeze(1)  # [batch_size]

    # Q_target = r if done else r + gamma * max_a Q_target(s', a)
    target = torch.where(
        dones,
        rewards,
        rewards + self.config.gamma * next_q_values
    )

    self.approx_network.optimizer.zero_grad()
    self.target_network.optimizer.zero_grad()

    loss = self.approx_network.criterion(q_chosen, target)
    writer.add_scalar("Loss/train", loss.item(), timestep)
    writer.add_scalar("Reward/train", torch.mean(rewards), timestep)

    loss.backward()
    if self.config.grad_clip:
        torch.nn.utils.clip_grad_norm_(self.approx_network.parameters(), self.config.clip_val)
    self.approx_network.optimizer.step()
    self.approx_network.scheduler.step()
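One thing I did double-check is the terminal masking in the target. A quick toy example of that torch.where logic (illustrative numbers only, with gamma = 0.99 assumed):

# Toy check of the target computation above (illustrative values only).
import torch

gamma = 0.99
rewards       = torch.tensor([1.0, -1.0, 0.0])
next_q_values = torch.tensor([2.0,  3.0, 4.0])
dones         = torch.tensor([False, True, False])

target = torch.where(dones, rewards, rewards + gamma * next_q_values)
print(target)   # tensor([ 2.9800, -1.0000,  3.9600]); the terminal transition keeps only its reward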
I want to stick to the paper as closely as possible, so I have not implemented Double DQN or prioritized experience replay, as others might suggest.
Your input tensor is laid out as H × W × C (84 × 84 × 4), but you build conv1 with in_channels=env.observation_space.shape[0] (which is 84) and feed the raw tensor straight into nn.Conv2d, which expects N × C × H × W.
The result: the network treats each row of the frame as a separate channel, while the four stacked frames are treated as width pixels. It happily trains on this scrambled view, so you get the "near-zero loss" (the target network quickly matches the main net) but no useful learning signal for Pong.
Fix the channel order and the first layer:
# preprocessing — put channels first and scale to [0, 1]
obs = torch.from_numpy(obs).float().div_(255)   # (84, 84, 4) -> float32 in [0, 1]
obs = obs.permute(2, 0, 1)                      # (C=4, H=84, W=84)

# model
self.conv1 = nn.Conv2d(in_channels=4, out_channels=32,
                       kernel_size=8, stride=4)

(You can also set grayscale_newaxis=True and let gymnasium.wrappers.FrameStack give you channel-first directly.)
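A quick way to confirm the fix took effect (a sketch; adapt the names to your process_state) is to inspect the tensor you are about to feed the network once, right after reset:

# One-off sanity check after the fix (sketch; `env` is the fully wrapped env).
obs, info = env.reset()
x = torch.from_numpy(obs).float().div_(255)
if x.shape[-1] == 4:                 # frames stacked last -> move them to the channel dim
    x = x.permute(2, 0, 1)
assert x.shape == (4, 84, 84), x.shape                   # channel-first, as nn.Conv2d expects
assert 0.0 <= x.min().item() <= x.max().item() <= 1.0    # scaled to [0, 1]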
A few side notes:
Drop BatchNorm – DQN’s replay buffer violates the i.i.d. assumption; BN usually hurts.
Keep reward clipping to ±1 and the Huber (SmoothL1) loss, as in the paper (see the sketch after these notes).
On Pong you should see the average return climb after ~1 M steps and hit ≥15 after 8–10 M with vanilla DQN.
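The clipping is just a clamp of the raw env.step() reward to [-1, 1] before the transition goes into the replay buffer. A minimal sketch:

def clip_reward(reward):
    # Reward clipping as in the Nature DQN paper: store the clipped value in the
    # experience tuple instead of the raw reward. Pong's rewards are already in
    # {-1, 0, +1}, so this matters more for other Atari games.
    return max(-1.0, min(1.0, float(reward)))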
After the channel fix, the training curve should start rising instead of flatlining.