python, deep-learning, pytorch, reinforcement-learning

DQN fails to learn good policy for Atari Pong


I'm trying to implement the DQN from DeepMind's 2015 Nature paper (Mnih et al., "Human-level control through deep reinforcement learning") from scratch in PyTorch, using the Atari Pong environment.

I've tested my Deep Q-Network on a simple test environment in which episodes last only 5 time steps and the maximum achievable reward is 4.1. My DQN reaches this reward in about 70% of training runs, so I'm confident that the implementation is functionally correct.

However, when I apply the exact same DQN algorithm to Atari Pong, the agent fails to learn or show any meaningful improvement, even after training for nearly 3 million time steps.

I've attached graphs of training metrics here:

[Figure: Performance over time]

[Figure: Weights and gradients]

[Figure: More weights and gradients]

These graphs are from my most recent training session. The network does not seem to learn: the max reward, the average reward, and the Q-values only decrease. Most puzzling, the loss is nearly zero, yet the network is not learning anything close to a good policy.
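For context, the paper tracks learning progress partly through the average maximum predicted Q-value over a fixed set of held-out states (collected by running a random policy before training starts), because that curve is much smoother than per-episode reward. A minimal sketch of that metric, assuming `q_network` maps a `[N, 4, 84, 84]` float batch to `[N, num_actions]` Q-values; the helper name `average_max_q` is hypothetical and not part of the repo:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def average_max_q(q_network: nn.Module, held_out_states: torch.Tensor) -> float:
    """Mean over a fixed batch of states of max_a Q(s, a)."""
    was_training = q_network.training
    q_network.eval()                        # use (and don't update) BatchNorm running stats, if any
    q_values = q_network(held_out_states)   # [N, num_actions]
    q_network.train(was_training)           # restore the previous mode
    return q_values.max(dim=1).values.mean().item()


# Usage with a stand-in network and random "states" (shapes only, not real Pong frames).
torch.manual_seed(0)
toy_net = nn.Sequential(nn.Flatten(), nn.Linear(4 * 84 * 84, 3))
held_out = torch.rand(16, 4, 84, 84)
print(average_max_q(toy_net, held_out))
```

Logging this value at a fixed interval gives a steadier signal than the reward curves when judging whether the Q-values are collapsing.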

The only hypothesis I have is training instability. I initially observed vanishing gradients, so I switched to torch.nn.SmoothL1Loss and added batch normalization (the BatchNorm layers are defined in the model below, but the forward() shown here does not apply them). This helped the gradients but did not improve performance.
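For reference, `torch.nn.SmoothL1Loss` with its default `beta=1.0` computes the same values as `torch.nn.HuberLoss` with `delta=1.0`, and its gradient with respect to each prediction is the TD error clipped to [-1, 1], which corresponds to the error clipping described in the paper. A small check on hand-picked TD errors (the tensors are illustrative only):

```python
import torch
import torch.nn as nn

td_target = torch.tensor([0.0, 0.0, 0.0, 0.0])
# Predictions chosen so the TD errors are both inside and outside [-1, 1].
q_pred = torch.tensor([0.3, -0.8, 2.5, -4.0], requires_grad=True)

smooth_l1 = nn.SmoothL1Loss(reduction="sum")      # default beta=1.0
huber = nn.HuberLoss(reduction="sum", delta=1.0)

loss = smooth_l1(q_pred, td_target)
loss.backward()

print(loss.item(), huber(q_pred, td_target).item())
# Gradient w.r.t. each prediction: the TD error clipped to [-1, 1].
print(q_pred.grad)
```

`reduction="sum"` is used here only so the per-element gradients are visible directly; with the default `"mean"` they are additionally divided by the batch size.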

My Setup

Full code on GitHub: https://github.com/rohanpatel01/Reinforcement_Learning/tree/try_testenv_dqn_but_not_work/Atari%20DQN%20Project

I removed some comments and logging so that the code here would be easier to read.

Model (DQN.py)

import torch
import torch.nn as nn
import torch.optim as optim

class NatureQN(nn.Module):

    def linear_decay(self, current_step):
        if current_step >= self.config.lr_n_steps:
            return self.config.lr_end / self.config.lr_begin
        return 1.0 - (current_step / self.config.lr_n_steps) * (1 - self.config.lr_end / self.config.lr_begin)


    # Following Model Architecture from mnih2015human
    def __init__(self, env, config, device):
        super(NatureQN, self).__init__()
        self.env = env
        self.config = config
        self.device = device

        self.conv1 = nn.Conv2d(in_channels=env.observation_space.shape[0], out_channels=32, kernel_size=8, stride=4, dtype=torch.float32)
        self.bn_conv1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, dtype=torch.float32)
        self.bn_conv2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, dtype=torch.float32)
        self.bn_conv3 = nn.BatchNorm2d(64)

        self.fc1 = nn.Linear(in_features=3136, out_features=512, dtype=torch.float32)
        self.bn_fc1 = nn.BatchNorm1d(512)

        self.fc2    = nn.Linear(in_features=512, out_features=env.action_space.n, dtype=torch.float32)

        self.ReLU = nn.ReLU()

        self.optimizer = optim.RMSprop(self.parameters(), lr=self.config.lr_begin, alpha=self.config.squared_gradient_momentum, eps=self.config.rms_eps)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.linear_decay)
        self.criterion = nn.SmoothL1Loss()

    def forward(self, x):
        x = self.conv1(x)
        x = self.ReLU(x)

        x = self.conv2(x)
        x = self.ReLU(x)

        x = self.conv3(x)
        x = self.ReLU(x)

        if len(x.shape) == 3:  # single input [num_channels, height, width]
            x = torch.flatten(x)
        elif len(x.shape) == 4:  # batch input [batch_size, num_channels, height, width]
            x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.ReLU(x)
        x = self.fc2(x)

        return x
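Because `scheduler.step()` runs once per gradient update in `train_on_minibatch`, `lr_n_steps` counts updates, not environment steps: with an update every `learning_freq` steps after `learning_start`, the decay finishes roughly `learning_start + lr_n_steps * learning_freq` environment steps into training. A standalone sketch of the same `linear_decay` multiplier driven by `LambdaLR`; the `LR_BEGIN`, `LR_END` and `LR_N_STEPS` values are hypothetical, since `AtariDQNConfig` is not shown:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical schedule values, for illustration only.
LR_BEGIN, LR_END, LR_N_STEPS = 2.5e-4, 5e-5, 1000


def linear_decay(step: int) -> float:
    # Multiplier on LR_BEGIN: 1.0 at step 0, LR_END / LR_BEGIN from LR_N_STEPS onwards.
    if step >= LR_N_STEPS:
        return LR_END / LR_BEGIN
    return 1.0 - (step / LR_N_STEPS) * (1 - LR_END / LR_BEGIN)


net = nn.Linear(4, 2)
optimizer = optim.RMSprop(net.parameters(), lr=LR_BEGIN, alpha=0.95, eps=0.01)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

lrs = []
for _ in range(LR_N_STEPS + 10):
    optimizer.step()   # one parameter update (a no-op here: no gradients were computed)...
    scheduler.step()   # ...then one scheduler step, mirroring train_on_minibatch
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[0], lrs[LR_N_STEPS // 2 - 1], lrs[-1])
```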

Main()

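import torch
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

def main():
    config = AtariDQNConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = gym.make("ALE/Pong-v5", frameskip=1, repeat_action_probability=0)
    env = ReducedActionSet(env, allowed_actions=[0, 2, 3])  # custom wrapper from the repo
    env = AtariPreprocessing(
            env,
            noop_max=30, frame_skip=4, terminal_on_life_loss=False,  # changed noop_max to 30 from 0
            screen_size=84, grayscale_obs=True, grayscale_newaxis=False,
            scale_obs=False
        )
    env = FrameStackObservation(env, stack_size=4)

    model = DQN(env, config, device)  # DQN: agent class from the repo (not shown here)
    model.train()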

Hyperparameters (AtariDQNConfig.py)
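The contents of `AtariDQNConfig.py` are not reproduced in this post. For comparison, here is a sketch holding the values listed in Extended Data Table 1 of the 2015 paper; the dataclass and its field names are hypothetical and are not the repo's config. One unit worth checking: the paper measures the target-network update frequency in parameter updates, while the training loop below syncs the target every `target_weight_update_freq` environment steps (`self.t`), so a value of 10,000 there syncs four times as often as the paper when `learning_freq` is 4.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PaperDQNHyperparams:
    # Values from Extended Data Table 1 of Mnih et al. (2015); field names are hypothetical.
    minibatch_size: int = 32
    replay_memory_size: int = 1_000_000
    agent_history_length: int = 4                   # frames stacked per state
    target_network_update_frequency: int = 10_000   # measured in parameter updates
    discount_factor: float = 0.99
    action_repeat: int = 4                          # frame skip
    update_frequency: int = 4                       # actions selected between SGD updates
    learning_rate: float = 0.00025                  # a single value in the table (no schedule listed)
    gradient_momentum: float = 0.95
    squared_gradient_momentum: float = 0.95
    min_squared_gradient: float = 0.01              # constant in the RMSProp denominator
    initial_exploration: float = 1.0
    final_exploration: float = 0.1
    final_exploration_frame: int = 1_000_000
    replay_start_size: int = 50_000                 # random-policy steps before learning starts
    noop_max: int = 30

    def target_update_period_in_agent_steps(self) -> int:
        # Target synced every C parameter updates = every C * update_frequency agent steps.
        return self.target_network_update_frequency * self.update_frequency


cfg = PaperDQNHyperparams()
print(cfg.target_update_period_in_agent_steps())
```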

I've also tried tweaking:

Training Loop (in Q_learning.py)

def train(self):

    epsilon_scheduler = EpsilonScheduler(self.config.begin_epsilon, self.config.end_epsilon,
                                         self.config.max_time_steps_update_epsilon)

    time_last_saved = self.t

    while self.t <= self.config.nsteps_train:

        print("Time: ", self.t, " Epsilon: ", epsilon_scheduler.get_epsilon(self.t - self.config.learning_delay), " Learning Rate: ", self.approx_network.optimizer.param_groups[0]['lr'])

        if (self.t - time_last_saved) >= self.config.saving_freq:
            self.save_snapshop(self.t, self.num_episodes, self.total_reward_so_far, self.replay_buffer) # just for development now we're not gonna save snapshots
            time_last_saved = self.t

        state, info = self.env.reset()
        state = torch.from_numpy(state)
        state = self.process_state(state)
        state = state.to(self.device)

        total_reward_for_episode = 0
        start_time = time.time()

        while True:
            with torch.no_grad():

                action = self.sample_action(self.env, state, epsilon_scheduler.get_epsilon(self.t - self.config.learning_delay), self.t, "approx")

                next_state, reward, terminated, truncated, info = self.env.step(action)
                next_state = torch.from_numpy(next_state)
                next_state = self.process_state(next_state).to(self.device)

                # convert state and next_state to uint8 before placing in replay buffer
                experience_tuple = ((state*self.config.high).to(torch.uint8).to('cpu'), action, reward, (next_state*self.config.high).to(torch.uint8).to('cpu'), terminated)
                self.replay_buffer.store(experience_tuple)

            state = next_state

            if (self.t > self.config.learning_start) and (self.t % self.config.learning_freq == 0):
                self.train_on_minibatch(self.replay_buffer.sample_minibatch(), self.t)

            
            self.monitor_performance(state, reward, monitor_end_of_episode=False, timestep=self.t)    # used to have env as param
            
            self.total_reward_so_far += reward
            total_reward_for_episode += reward

            self.t += 1
            if (self.t > self.config.learning_start) and (self.t % self.config.target_weight_update_freq == 0):
                self.set_target_weights()

            if terminated or truncated:  # also end the episode on time-limit truncation; only `terminated` is stored as done
                # monitor avg reward per episode and max_reward per episode (at end of episode)
                self.num_episodes += 1
                self.monitor_performance(state, reward, monitor_end_of_episode=True, timestep=self.t, context = (self.total_reward_so_far, self.num_episodes, total_reward_for_episode))
                break
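Neither `EpsilonScheduler` nor `sample_action` is shown in the post, so the following is only a hypothetical stand-in: a linear schedule from 1.0 to 0.1 over 1,000,000 steps (the paper's values) plus epsilon-greedy action selection. Because the loop calls `get_epsilon(self.t - self.config.learning_delay)`, the argument can be negative early in training; this sketch clamps it so exploration stays at its initial value until then.

```python
import random


class LinearEpsilon:
    """Hypothetical stand-in for the repo's EpsilonScheduler: linear anneal, then constant."""

    def __init__(self, begin: float, end: float, n_steps: int):
        self.begin, self.end, self.n_steps = begin, end, n_steps

    def get_epsilon(self, step: int) -> float:
        step = max(step, 0)  # negative offsets (before learning_delay) keep full exploration
        if step >= self.n_steps:
            return self.end
        return self.begin + (self.end - self.begin) * step / self.n_steps


def epsilon_greedy(q_values: list[float], epsilon: float, rng: random.Random) -> int:
    # With probability epsilon pick a uniformly random action, otherwise the greedy one.
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


sched = LinearEpsilon(begin=1.0, end=0.1, n_steps=1_000_000)
rng = random.Random(0)
for step in (-50_000, 0, 500_000, 2_000_000):
    print(step, sched.get_epsilon(step))
print(epsilon_greedy([0.1, 0.7, -0.2], epsilon=0.0, rng=rng))
```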

Update function (train_on_minibatch() in Linear.py)

def train_on_minibatch(self, minibatch, timestep):
    states, actions, rewards, next_states, dones = minibatch

    states = self.process_state(states)
    next_states = self.process_state(next_states)

    states = states.to(self.device).to(torch.float32)
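    actions = actions.to(self.device).to(torch.long)           # gather() needs int64 indices
    rewards = rewards.to(self.device).to(torch.float32)
    next_states = next_states.to(self.device).to(torch.float32)
    dones = dones.to(self.device).to(torch.bool)               # torch.where() needs a boolean mask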

    q_vals = self.approx_network(states)
    q_chosen = q_vals.gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrapped target from the frozen target network (no gradients)
    with torch.no_grad():
        q_next_all = self.target_network(next_states)  # [batch_size, num_actions]
        best_actions = torch.argmax(q_next_all, dim=1)  # [batch_size]
        next_q_values = q_next_all.gather(1, best_actions.unsqueeze(1)).squeeze(1)  # [batch_size]

        # Q_target = r if done else r + gamma * max_a Q_target(s', a)
        target = torch.where(
            dones,
            rewards,
            rewards + self.config.gamma * next_q_values 
        )

    self.approx_network.optimizer.zero_grad()
    self.target_network.optimizer.zero_grad()

    loss = self.approx_network.criterion(q_chosen, target)
    writer.add_scalar("Loss/train", loss.item(), timestep)
    writer.add_scalar("Reward/train", torch.mean(rewards), timestep)
    loss.backward()
    
    if self.config.grad_clip:
        torch.nn.utils.clip_grad_norm_(self.approx_network.parameters(), self.config.clip_val)

    self.approx_network.optimizer.step()
    self.approx_network.scheduler.step()
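The target computed in `train_on_minibatch` can be checked in isolation on a tiny hand-made batch. A minimal sketch, assuming `rewards` is a float tensor, `dones` is a bool tensor (`torch.where` expects a boolean condition) and `next_q_target` holds the target network's Q-values for the next states; `dqn_targets` is a hypothetical helper, not a repo function:

```python
import torch


def dqn_targets(rewards: torch.Tensor, next_q_target: torch.Tensor,
                dones: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """One-step DQN targets: r if terminal else r + gamma * max_a Q_target(s', a)."""
    next_max = next_q_target.max(dim=1).values  # same values as argmax followed by gather
    return torch.where(dones, rewards, rewards + gamma * next_max)


rewards = torch.tensor([0.0, 1.0, -1.0])
next_q = torch.tensor([[0.5, 2.0, 1.0],
                       [3.0, 0.0, 0.0],
                       [1.0, 1.0, 4.0]])
dones = torch.tensor([False, True, False])
print(dqn_targets(rewards, next_q, dones, gamma=0.9))
```

The terminal transition keeps only its reward, and the other two bootstrap from the largest next-state Q-value, which is exactly what the argmax/gather pair in `train_on_minibatch` computes.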

I want to stick to the paper as closely as possible, so I did not implement Double DQN or prioritized experience replay, as others might suggest.


Solution

  • Your input tensor is laid out in H × W × C (84 × 84 × 4), but you build conv1 with

    in_channels = env.observation_space.shape[0]      # 84
    

    and feed the raw tensor straight into nn.Conv2d, which expects N × C × H × W.
    The result: the network thinks each row of the frame is a separate channel, while the four stacked frames are treated as width pixels. It happily trains on this scrambled view, so you get “near-zero loss” (the target network quickly matches the main net) but no learning signal for Pong.

    Fix the channel order and the first layer:

    # preprocessing: put channels first and scale to [0, 1]
    obs = torch.from_numpy(obs).float().div_(255)     # (84,84,4) → float
    obs = obs.permute(2, 0, 1)                        # (C=4,H=84,W=84)
    
    # model
    self.conv1 = nn.Conv2d(in_channels=4, out_channels=32,
                           kernel_size=8, stride=4)
    

    (Alternatively, keep grayscale_newaxis=False: gymnasium.wrappers.FrameStackObservation stacks frames along a new leading axis, so it already returns channel-first (4, 84, 84) observations.)

    A few side notes:

    After the channel fix, the training curve should start rising instead of flatlining.
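Before changing the model, it may be worth confirming which layout actually reaches conv1. A minimal sketch, assuming single observations of four stacked 84×84 grayscale frames; `check_layout` and `conv_trunk` are hypothetical helpers, the latter mirroring the conv layers of NatureQN. It also shows that a channel-last batch cannot pass through this stack unnoticed: PyTorch reads (1, 84, 84, 4) as 84 channels of 84×4 images, and the 8×8 kernel of conv1 does not fit a width of 4, so the forward pass raises a RuntimeError.

```python
import torch
import torch.nn as nn


def conv_trunk(in_channels: int) -> nn.Sequential:
    # Same conv layers as NatureQN (BatchNorm left out, as in its forward()), then flatten.
    return nn.Sequential(
        nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        nn.Flatten(),
    )


def check_layout(obs: torch.Tensor, frames: int = 4, size: int = 84) -> str:
    """Classify one stacked observation as channel-first ("CHW") or channel-last ("HWC")."""
    if obs.dim() != 3:
        raise ValueError(f"expected a 3-D observation, got shape {tuple(obs.shape)}")
    if obs.shape[0] == frames and tuple(obs.shape[1:]) == (size, size):
        return "CHW"
    if obs.shape[-1] == frames and tuple(obs.shape[:2]) == (size, size):
        return "HWC"
    return "unknown"


chw_obs = torch.zeros(4, 84, 84)
hwc_obs = torch.zeros(84, 84, 4)
print(check_layout(chw_obs), check_layout(hwc_obs))

# Channel-first batch: 84 -> 20 -> 9 -> 7 spatially, so 64 * 7 * 7 = 3136 features, matching fc1.
print(conv_trunk(4)(chw_obs.unsqueeze(0)).shape)

# Channel-last batch read as C=84, H=84, W=4: the 8x8 kernel of conv1 cannot fit a width of 4.
try:
    conv_trunk(84)(hwc_obs.unsqueeze(0))
except RuntimeError as err:
    print("RuntimeError:", err)
```

With the real environment, calling `check_layout(torch.from_numpy(obs))` on the `obs` returned by `env.reset()` reports the layout the wrappers actually produce.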