I'm working with an LSTM-based RecurrentPPO agent that needs a behavioural cloning implementation.
The imitation library that accompanies Stable Baselines 3 (see https://imitation.readthedocs.io/en/latest/) does not seem to support SB3-contrib's RecurrentPPO.
I found this method that could be adapted for RecurrentPPO: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pretraining.ipynb
I guess this part of the code has to be modified to take lstm_states and episode_starts into account, but I don't know how to implement it (my rough attempt at an adaptation follows the code).
# Note: `env`, `train_expert_dataset` and `test_expert_dataset` are created earlier
# in the notebook; the imports below are the ones this snippet relies on.
import gym
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from stable_baselines3 import A2C, PPO


def pretrain_agent(
    student,
    batch_size=64,
    epochs=1000,
    scheduler_gamma=0.7,
    learning_rate=1.0,
    log_interval=100,
    no_cuda=True,
    seed=1,
    test_batch_size=64,
):
    use_cuda = not no_cuda and th.cuda.is_available()
    th.manual_seed(seed)
    device = th.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    if isinstance(env.action_space, gym.spaces.Box):
        criterion = nn.MSELoss()
    else:
        criterion = nn.CrossEntropyLoss()

    # Extract initial policy
    model = student.policy.to(device)

    def train(model, device, train_loader, optimizer):
        model.train()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            if isinstance(env.action_space, gym.spaces.Box):
                # A2C/PPO policy outputs actions, values, log_prob
                # SAC/TD3 policy outputs actions only
                if isinstance(student, (A2C, PPO)):
                    action, _, _ = model(data)
                else:
                    # SAC/TD3:
                    action = model(data)
                action_prediction = action.double()
            else:
                # Retrieve the logits for A2C/PPO when using discrete actions
                dist = model.get_distribution(data)
                action_prediction = dist.distribution.logits
                target = target.long()

            loss = criterion(action_prediction, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    def test(model, device, test_loader):
        model.eval()
        test_loss = 0
        with th.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)

                if isinstance(env.action_space, gym.spaces.Box):
                    # A2C/PPO policy outputs actions, values, log_prob
                    # SAC/TD3 policy outputs actions only
                    if isinstance(student, (A2C, PPO)):
                        action, _, _ = model(data)
                    else:
                        # SAC/TD3:
                        action = model(data)
                    action_prediction = action.double()
                else:
                    # Retrieve the logits for A2C/PPO when using discrete actions
                    dist = model.get_distribution(data)
                    action_prediction = dist.distribution.logits
                    target = target.long()

                # Accumulate the loss over all test batches
                test_loss += criterion(action_prediction, target).item()
        test_loss /= len(test_loader)
        print(f"Test set: Average loss: {test_loss:.4f}")

    # Here, we use PyTorch `DataLoader` to load the previously created `ExpertDataset`
    # for training and testing
    train_loader = th.utils.data.DataLoader(
        dataset=train_expert_dataset, batch_size=batch_size, shuffle=True, **kwargs
    )
    test_loader = th.utils.data.DataLoader(
        dataset=test_expert_dataset,
        batch_size=test_batch_size,
        shuffle=True,
        **kwargs,
    )

    # Define an Optimizer and a learning rate schedule.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=scheduler_gamma)

    # Now we are finally ready to train the policy model.
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        scheduler.step()

    # Implant the trained policy network back into the RL student agent
    student.policy = model
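Here is roughly how I imagine the discrete-action branch of train() would have to change for RecurrentPPO. I am assuming that RecurrentActorCriticPolicy.get_distribution(obs, lstm_states, episode_starts) returns the distribution together with the updated states, that model.lstm_actor exposes the actor LSTM, and that each batch is one whole trajectory rather than shuffled transitions. I have not verified any of this, so treat it as a sketch only:

    def train_recurrent(model, device, trajectory_loader, optimizer):
        # Sketch only: `trajectory_loader` is assumed to yield ONE whole trajectory
        # per batch as (obs, target), with obs of shape (T, obs_dim) and target (T,).
        # `criterion` is the CrossEntropyLoss from the enclosing pretrain_agent().
        model.train()
        for batch_idx, (obs, target) in enumerate(trajectory_loader):
            obs, target = obs.to(device), target.to(device).long()
            optimizer.zero_grad()

            # Fresh hidden/cell state at the start of every trajectory
            # (shapes taken from the actor LSTM of the recurrent policy).
            lstm = model.lstm_actor
            h = th.zeros(lstm.num_layers, 1, lstm.hidden_size, device=device)
            lstm_states = (h, h.clone())

            # episode_starts is 1 only at the first timestep, so the policy
            # resets its memory there and keeps it for the rest of the trajectory.
            episode_starts = th.zeros(obs.shape[0], device=device)
            episode_starts[0] = 1.0

            dist, lstm_states = model.get_distribution(obs, lstm_states, episode_starts)
            action_prediction = dist.distribution.logits
            loss = criterion(action_prediction, target)
            loss.backward()
            optimizer.step()

Is something like this the right direction?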
Does anyone have a solution?
Just stumbled upon this problem as well.
Traceback (most recent call last):
  File "my_imitate.py", line 49, in <module>
    bc_trainer.train(n_epochs=1)
  File "python3.8/site-packages/imitation/algorithms/bc.py", line 470, in train
    training_metrics = self.loss_calculator(self.policy, obs, acts)
  File "python3.8/site-packages/imitation/algorithms/bc.py", line 119, in __call__
    _, log_prob, entropy = policy.evaluate_actions(obs, acts)
TypeError: evaluate_actions() missing 2 required positional arguments: 'lstm_states' and 'episode_starts'
The problem is obviously that evaluate_actions in RecurrentActorCriticPolicy has a different signature: it also needs lstm_states and episode_starts.
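Simplified, the two signatures look something like this (paraphrased from the two code bases, so the exact type hints may differ):

# stable_baselines3 ActorCriticPolicy -- what imitation's BC loss calls:
def evaluate_actions(self, obs, actions):
    ...  # returns values, log_prob, entropy

# sb3_contrib RecurrentActorCriticPolicy -- additionally needs the recurrent state:
def evaluate_actions(self, obs, actions, lstm_states, episode_starts):
    ...  # returns values, log_prob, entropy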
My first thought was that this information would also need to be stored during rollout collection (which I assumed it was, but it is not), and that the solution would be to store the missing fields during rollout collection and handle them during BC whenever they are present and compatible with the policy at hand.
But it is actually unclear what the expert's state would even be when the expert policy used for rollout collection is not recurrent itself (e.g. a near-optimal search algorithm). So for recurrent policies, the BC algorithm should instead train on whole trajectories from beginning to end, carrying the lstm_states forward between timesteps.
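A minimal sketch of what I mean, assuming the recurrent policy's get_distribution(obs, lstm_states, episode_starts) returns (distribution, new_states) and that lstm_actor exposes the actor LSTM; the data handling (one tensor of observations and one of actions per trajectory) is made up for the example:

import torch as th

def bc_update_on_trajectory(policy, optimizer, traj_obs, traj_acts, device):
    # traj_obs: (T, obs_dim) float tensor, traj_acts: (T,) or (T, action_dim) tensor,
    # both covering ONE expert trajectory from its first to its last step.
    policy.train()
    lstm = policy.lstm_actor
    h = th.zeros(lstm.num_layers, 1, lstm.hidden_size, device=device)
    lstm_states = (h, h.clone())

    log_probs = []
    for t in range(traj_obs.shape[0]):
        obs_t = traj_obs[t].unsqueeze(0).to(device)
        # Reset the memory only at the very first step of the trajectory,
        # then carry the LSTM state forward between timesteps.
        episode_start = th.tensor([1.0 if t == 0 else 0.0], device=device)
        dist, lstm_states = policy.get_distribution(obs_t, lstm_states, episode_start)
        log_probs.append(dist.log_prob(traj_acts[t].unsqueeze(0).to(device)))

    # Plain behavioural-cloning objective: maximise the likelihood of the expert actions.
    loss = -th.cat(log_probs).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

This may well not match the real sb3-contrib API exactly, but it captures the idea: one update per trajectory, with the hidden state threaded through every step.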