python, pytorch, reinforcement-learning, stable-baselines, stablebaseline3

Training a Custom Feature Extractor in Stable Baselines3 Starting from Pre-trained Weights?


I am using the following custom feature extractor for my Stable Baselines3 model:

import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim=2):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.ReLU()
        )
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.regressor(x)
        return x
    
model = Encoder(input_dim, embedding_dim, hidden_dim)
model.load_state_dict(torch.load('trained_model.pth'))

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

class CustomFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)
        self.model = model  # Use the pre-trained model as the feature extractor

        self._features_dim = features_dim

    def forward(self, observations):
        features = self.model(observations)
        return features

policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractor,
    "features_extractor_kwargs": {"features_dim": 64}
}

model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)

The model has trained well so far, with no issues and good results. Now I want to unfreeze the weights and train the feature extractor as well, starting from the initial pre-trained weights. How can I do that with a custom feature extractor that wraps a pre-trained model defined in another class? My feature extractor is not the same as the one defined in the documentation, so I am not sure whether it will be trained. Or will it start training if I simply unfreeze the layers?


Solution

  • UPDATED answer

    Because your CustomFeatureExtractor wraps the Encoder you already froze (every parameter has requires_grad = False), all of the feature extractor's weights are frozen as well, so by default it will not be trained. You need to unfreeze it manually after creating the PPO model:

    
    model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)
    
    # get model feature extractor
    feature_extr: CustomFeatureExtractor = model.policy.features_extractor
    
    # make all feature extractor parameters trainable again
    for param in feature_extr.parameters():
        param.requires_grad = True
    
    # check parameters before training
    encoder = feature_extr.model.encoder
    for name, param in encoder[0].named_parameters():
        print(name, param.mean())
    
    
    # train the model
    model.learn(total_timesteps=5)
    
    
    # check parameters after training (if the means changed, the parameters are being trained)
    feature_extr: CustomFeatureExtractor = model.policy.features_extractor
    encoder = feature_extr.model.encoder
    for name, param in encoder[0].named_parameters():
        print(name, param.mean())
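
    If you would rather start fine-tuning from the pre-trained weights right away instead of unfreezing after the fact, you can also skip the freezing loop entirely and load the checkpoint inside the feature extractor. Both approaches rely on the same behaviour: Stable Baselines3 passes all policy parameters (including the feature extractor's) to the optimizer when the policy is built, so every parameter with requires_grad = True is updated during learn(). Below is a minimal sketch assuming the Encoder class, the input_dim / embedding_dim / hidden_dim values and the envs environment from your question; PretrainedFeatureExtractor and the weights_path argument are just illustrative names.

    import torch
    from stable_baselines3 import PPO
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

    class PretrainedFeatureExtractor(BaseFeaturesExtractor):
        def __init__(self, observation_space, features_dim, weights_path="trained_model.pth"):
            super().__init__(observation_space, features_dim)
            # Rebuild the encoder and load the pre-trained checkpoint.
            # No requires_grad = False here, so PPO fine-tunes these layers
            # together with the rest of the policy during learn().
            self.model = Encoder(input_dim, embedding_dim, hidden_dim)
            self.model.load_state_dict(torch.load(weights_path))

        def forward(self, observations):
            return self.model(observations)

    # features_dim should match the output size of Encoder (hidden_dim)
    policy_kwargs = {
        "features_extractor_class": PretrainedFeatureExtractor,
        "features_extractor_kwargs": {"features_dim": 64},
    }
    model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)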