python pytorch reinforcement-learning

How does TorchRL deal with multiple trajectories in the same batch?


I am trying to understand how the DQN algorithm with RNNs works in PyTorch's RL API through this tutorial. However, the way some of the classes handle episodes and batches during training is unclear to me.

I noticed that SyncDataCollector stitches episodes together into the same TensorDict until it reaches the desired number of frames per batch. I am wondering if this means that during training, the DQNLoss treats all these stitched episodes as a single trajectory or if it knows to separate them back into individual episodes. If they are treated as one, wouldn't this mix different trajectories and affect learning? Additionally, does the replay buffer randomly sample states and actions without tracking which episode they belong to, potentially mixing unrelated data?


Solution

  • Collectors may return batches that combine splits of different trajectories.

    If you feed such a batch directly to your loss, that is fine: losses and advantage functions handle stacks of trajectories natively by reading the done/terminated/truncated markers, so there is no cross-contamination.
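    A minimal sketch of what this looks like in practice (the env name and frame counts are arbitrary, and it assumes gymnasium is installed): the collector returns one flat TensorDict per batch, and the episode boundaries stay visible through the done flags and the per-step trajectory ids the collector writes, which is all the loss/advantage modules need. split_trajectories shows the same data reorganized per episode.

        from torchrl.envs import GymEnv
        from torchrl.collectors import SyncDataCollector
        from torchrl.collectors.utils import split_trajectories

        env = GymEnv("CartPole-v1")
        collector = SyncDataCollector(
            env,
            policy=None,            # None -> random policy, just to generate data
            frames_per_batch=200,   # usually longer than a single CartPole episode
            total_frames=200,
        )

        batch = next(iter(collector))          # one flat TensorDict of shape [200]
        print(batch["next", "done"].sum())     # > 1: several episodes were stitched together
        print(batch["collector", "traj_ids"])  # per-step trajectory id written by the collector

        # Losses/advantage modules read these markers; you can also split explicitly:
        trajs = split_trajectories(batch)      # reorganized as [n_trajs, max_len], padded
        print(trajs.batch_size)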

    With DQNLoss, you will write the data in a replay buffer and then either sample single transitions or entire (slices of) trajectories using a SliceSampler. In either case, you can also be sure that there is no cross-contamination.
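    As a rough sketch (reusing the batch TensorDict from the previous snippet, which carries the ("collector", "traj_ids") entry; the slice length and buffer size are arbitrary), a buffer built with a SliceSampler only ever returns contiguous chunks of a single episode:

        from torchrl.data import LazyTensorStorage, SliceSampler, TensorDictReplayBuffer

        rb = TensorDictReplayBuffer(
            storage=LazyTensorStorage(10_000),
            sampler=SliceSampler(
                slice_len=8,                          # length of each sampled sub-trajectory
                traj_key=("collector", "traj_ids"),   # how episodes are told apart
            ),
            batch_size=32,                            # 32 steps = 4 slices of 8 steps
        )

        rb.extend(batch)       # write the stitched collector batch
        sample = rb.sample()   # flat TensorDict of shape [32]
        sample = sample.reshape(-1, 8)   # [4, 8]: each row comes from one episode only, ready for an RNN loss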


    Finally, LLM collectors have the ability to yield / write only full trajectories into the buffer. This is a nice feature and we could bring it to regular collectors; the only obstacle is that these classes have become massive over the years and refactoring them is not trivial.